diff --git a/twisted/protocols/socks.py b/twisted/protocols/socks.py index 332f77385b7..a52c09b6697 100644 --- a/twisted/protocols/socks.py +++ b/twisted/protocols/socks.py @@ -18,21 +18,24 @@ class SOCKSv4Outgoing(protocol.Protocol): - - def __init__(self,socks): + def __init__(self, socks): self.socks=socks + def connectionMade(self): peer = self.transport.getPeer() self.socks.makeReply(90, 0, port=peer.port, ip=peer.host) self.socks.otherConn=self + def connectionLost(self, reason): self.socks.transport.loseConnection() - def dataReceived(self,data): + + def dataReceived(self, data): self.socks.write(data) + def write(self,data): self.socks.log(self,data) self.transport.write(data) @@ -40,34 +43,37 @@ def write(self,data): class SOCKSv4Incoming(protocol.Protocol): - def __init__(self,socks): self.socks=socks self.socks.otherConn=self + def connectionLost(self, reason): self.socks.transport.loseConnection() + def dataReceived(self,data): self.socks.write(data) - def write(self,data): + + def write(self, data): self.socks.log(self,data) self.transport.write(data) + class SOCKSv4(protocol.Protocol): """ An implementation of the SOCKSv4 protocol. - @type logging: C{str} or L{None} + @type logging: L{str} or L{None} @ivar logging: If not L{None}, the name of the logfile to which connection information will be written. @type reactor: object providing L{twisted.internet.interfaces.IReactorTCP} @ivar reactor: The reactor used to create connections. - @type buf: C{str} + @type buf: L{str} @ivar buf: Part of a SOCKSv4 connection request. @type otherConn: C{SOCKSv4Incoming}, C{SOCKSv4Outgoing} or L{None} @@ -79,15 +85,17 @@ def __init__(self, logging=None, reactor=reactor): self.logging = logging self.reactor = reactor + def connectionMade(self): - self.buf = "" + self.buf = b"" self.otherConn = None + def dataReceived(self, data): """ Called whenever data is received. - @type data: C{str} + @type data: L{bytes} @param data: Part or all of a SOCKSv4 packet. """ if self.otherConn: @@ -95,19 +103,19 @@ def dataReceived(self, data): return self.buf = self.buf + data completeBuffer = self.buf - if "\000" in self.buf[8:]: + if b"\000" in self.buf[8:]: head, self.buf = self.buf[:8], self.buf[8:] version, code, port = struct.unpack("!BBH", head[:4]) - user, self.buf = self.buf.split("\000", 1) - if head[4:7] == "\000\000\000" and head[7] != "\000": + user, self.buf = self.buf.split(b"\000", 1) + if head[4:7] == b"\000\000\000" and head[7:8] != b"\000": # An IP address of the form 0.0.0.X, where X is non-zero, # signifies that this is a SOCKSv4a packet. # If the complete packet hasn't been received, restore the # buffer and wait for it. - if "\000" not in self.buf: + if b"\000" not in self.buf: self.buf = completeBuffer return - server, self.buf = self.buf.split("\000", 1) + server, self.buf = self.buf.split(b"\000", 1) d = self.reactor.resolve(server) d.addCallback(self._dataReceived2, user, version, code, port) @@ -118,27 +126,28 @@ def dataReceived(self, data): self._dataReceived2(server, user, version, code, port) + def _dataReceived2(self, server, user, version, code, port): """ The second half of the SOCKS connection setup. For a SOCKSv4 packet this is after the server address has been extracted from the header. For a SOCKSv4a packet this is after the host name has been resolved. - @type server: C{str} + @type server: L{str} @param server: The IP address of the destination, represented as a dotted quad. - @type user: C{str} + @type user: L{str} @param user: The username associated with the connection. - @type version: C{int} + @type version: L{int} @param version: The SOCKS protocol version number. - @type code: C{int} + @type code: L{int} @param code: The comand code. 1 means establish a TCP/IP stream connection, and 2 means establish a TCP/IP port binding. - @type port: C{int} + @type port: L{int} @param port: The port number associated with the connection. """ assert version == 4, "Bad version code: %s" % version @@ -154,32 +163,39 @@ def _dataReceived2(self, server, user, version, code, port): self = self: self.makeReply(90, 0, x[1], x[0])) else: raise RuntimeError("Bad Connect Code: %s" % (code,)) - assert self.buf == "", "hmm, still stuff in buffer... %s" % repr( + assert self.buf == b"", "hmm, still stuff in buffer... %s" % repr( self.buf) + def connectionLost(self, reason): if self.otherConn: self.otherConn.transport.loseConnection() + def authorize(self,code,server,port,user): log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user)) return 1 + def connectClass(self, host, port, klass, *args): return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port) + def listenClass(self, port, klass, *args): serv = reactor.listenTCP(port, klass(*args)) return defer.succeed(serv.getHost()[1:]) + def makeReply(self,reply,version=0,port=0,ip="0.0.0.0"): self.transport.write(struct.pack("!BBH",version,reply,port)+socket.inet_aton(ip)) if reply!=90: self.transport.loseConnection() + def write(self,data): self.log(self,data) self.transport.write(data) + def log(self,proto,data): if not self.logging: return peer = self.transport.getPeer() @@ -208,10 +224,10 @@ class SOCKSv4Factory(protocol.Factory): Constructor accepts one argument, a log file name. """ - def __init__(self, log): self.logging = log + def buildProtocol(self, addr): return SOCKSv4(self.logging, reactor) @@ -221,7 +237,6 @@ class SOCKSv4IncomingFactory(protocol.Factory): """ A utility class for building protocols for incoming connections. """ - def __init__(self, socks, ip): self.socks = socks self.ip = ip diff --git a/twisted/protocols/test/test_tls.py b/twisted/protocols/test/test_tls.py index b6c95515f8e..29463743ea5 100644 --- a/twisted/protocols/test/test_tls.py +++ b/twisted/protocols/test/test_tls.py @@ -21,7 +21,7 @@ # Otherwise, the pyOpenSSL dependency must be satisfied, so all these # imports will work. from OpenSSL.crypto import X509Type - from OpenSSL.SSL import (TLSv1_METHOD, Error, Context, ConnectionType, + from OpenSSL.SSL import (TLSv1_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD, Error, Context, ConnectionType, WantReadError) from twisted.internet.ssl import PrivateCertificate, optionsForClientTLS from twisted.test.ssl_helpers import (ClientTLSContext, ServerTLSContext, @@ -62,8 +62,9 @@ class HandshakeCallbackContextFactory: # https://bugs.launchpad.net/pyopenssl/+bug/372832 SSL_CB_HANDSHAKE_DONE = 0x20 - def __init__(self): + def __init__(self, method=TLSv1_METHOD): self._finished = Deferred() + self._method = method def factoryAndDeferred(cls): @@ -93,7 +94,7 @@ def getContext(self): Create and return an SSL context configured to use L{self._info} as the info callback. """ - context = Context(TLSv1_METHOD) + context = Context(self._method) context.set_info_callback(self._info) return context @@ -704,14 +705,14 @@ def cbConnectionDone(ignored): return connectionDeferred - def test_hugeWrite(self): + def hugeWrite(self, method=TLSv1_METHOD): """ If a very long string is passed to L{TLSMemoryBIOProtocol.write}, any trailing part of it which cannot be send immediately is buffered and sent later. """ bytes = b"some bytes" - factor = 8192 + factor = 2 ** 20 class SimpleSendingProtocol(Protocol): def connectionMade(self): self.transport.write(bytes * factor) @@ -719,7 +720,7 @@ def connectionMade(self): clientFactory = ClientFactory() clientFactory.protocol = SimpleSendingProtocol - clientContextFactory = HandshakeCallbackContextFactory() + clientContextFactory = HandshakeCallbackContextFactory(method=method) wrapperFactory = TLSMemoryBIOFactory( clientContextFactory, True, clientFactory) sslClientProtocol = wrapperFactory.buildProtocol(None) @@ -728,7 +729,7 @@ def connectionMade(self): serverFactory = ServerFactory() serverFactory.protocol = lambda: serverProtocol - serverContextFactory = ServerTLSContext() + serverContextFactory = ServerTLSContext(method=method) wrapperFactory = TLSMemoryBIOFactory( serverContextFactory, False, serverFactory) sslServerProtocol = wrapperFactory.buildProtocol(None) @@ -742,6 +743,14 @@ def cbConnectionDone(ignored): connectionDeferred.addCallback(cbConnectionDone) return connectionDeferred + def test_hugeWrite_TLSv1(self): + return self.hugeWrite() + + def test_hugeWrite_TLSv1_1(self): + return self.hugeWrite(method=TLSv1_1_METHOD) + + def test_hugeWrite_TLSv1_2(self): + return self.hugeWrite(method=TLSv1_2_METHOD) def test_disorderlyShutdown(self): """ diff --git a/twisted/protocols/tls.py b/twisted/protocols/tls.py index 423b5fe0b2c..0bb9fdf6b3e 100644 --- a/twisted/protocols/tls.py +++ b/twisted/protocols/tls.py @@ -614,7 +614,7 @@ def _write(self, bytes): return # A TLS payload is 16kB max - bufferSize = 2 ** 16 + bufferSize = 2 ** 14 # How far into the input we've gotten so far alreadySent = 0 diff --git a/twisted/python/dist3.py b/twisted/python/dist3.py index ac52e89452e..d3e186b2336 100644 --- a/twisted/python/dist3.py +++ b/twisted/python/dist3.py @@ -224,6 +224,7 @@ "twisted.protocols.haproxy.test.test_v2parser", "twisted.protocols.haproxy.test.test_wrapper", "twisted.protocols.haproxy.test.test_parser", + "twisted.protocols.socks", "twisted.protocols.tls", "twisted.python.__init__", "twisted.python._appdirs", @@ -494,6 +495,7 @@ "twisted.test.test_reflect", "twisted.test.test_roots", "twisted.test.test_sob", + "twisted.test.test_socks", "twisted.test.test_ssl", "twisted.test.test_sslverify", "twisted.test.test_stdio", diff --git a/twisted/test/ssl_helpers.py b/twisted/test/ssl_helpers.py index 04a55d7b888..fdf27152fde 100644 --- a/twisted/test/ssl_helpers.py +++ b/twisted/test/ssl_helpers.py @@ -27,11 +27,12 @@ def getContext(self): class ServerTLSContext: isClient = 0 - def __init__(self, filename=certPath): + def __init__(self, filename=certPath, method=SSL.TLSv1_METHOD): self.filename = filename + self._method = method def getContext(self): - ctx = SSL.Context(SSL.TLSv1_METHOD) + ctx = SSL.Context(self._method) ctx.use_certificate_file(self.filename) ctx.use_privatekey_file(self.filename) return ctx diff --git a/twisted/test/test_socks.py b/twisted/test/test_socks.py index f4543d622f4..38283ef88a3 100644 --- a/twisted/test/test_socks.py +++ b/twisted/test/test_socks.py @@ -8,11 +8,12 @@ import struct, socket -from twisted.trial import unittest -from twisted.test import proto_helpers from twisted.internet import defer, address from twisted.internet.error import DNSLookupError +from twisted.python.compat import iterbytes from twisted.protocols import socks +from twisted.test import proto_helpers +from twisted.trial import unittest class StringTCPTransport(proto_helpers.StringTransport): @@ -22,9 +23,11 @@ class StringTCPTransport(proto_helpers.StringTransport): def getPeer(self): return self.peer + def getHost(self): return address.IPv4Address('TCP', '2.3.4.5', 42) + def loseConnection(self): self.stringTCPTransport_closing = True @@ -36,7 +39,7 @@ class FakeResolverReactor: """ def __init__(self, names): """ - @type names: C{dict} containing C{str} keys and C{str} values. + @type names: L{dict} containing L{str} keys and L{str} values. @param names: A hostname to IP address mapping. The IP addresses are stringified dotted quads. """ @@ -51,7 +54,8 @@ def resolve(self, hostname): return defer.succeed(self.names[hostname]) except KeyError: return defer.fail( - DNSLookupError("FakeResolverReactor couldn't find " + hostname)) + DNSLookupError("FakeResolverReactor couldn't find " + + hostname.decode("utf-8"))) @@ -71,6 +75,7 @@ def connectClass(self, host, port, klass, *args): self.driver_outgoing = proto return defer.succeed(proto) + def listenClass(self, port, klass, *args): # fake it factory = klass(*args) @@ -89,7 +94,7 @@ def setUp(self): self.sock = SOCKSv4Driver() self.sock.transport = StringTCPTransport() self.sock.connectionMade() - self.sock.reactor = FakeResolverReactor({"localhost":"127.0.0.1"}) + self.sock.reactor = FakeResolverReactor({b"localhost":"127.0.0.1"}) def tearDown(self): @@ -103,8 +108,8 @@ def test_simple(self): self.sock.dataReceived( struct.pack('!BBH', 4, 1, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') sent = self.sock.transport.value() self.sock.transport.clear() self.assertEqual(sent, @@ -114,13 +119,13 @@ def test_simple(self): self.assertIsNotNone(self.sock.driver_outgoing) # pass some data through - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual(self.sock.driver_outgoing.transport.value(), - 'hello, world') + b'hello, world') # the other way around - self.sock.driver_outgoing.dataReceived('hi there') - self.assertEqual(self.sock.transport.value(), 'hi there') + self.sock.driver_outgoing.dataReceived(b'hi there') + self.assertEqual(self.sock.transport.value(), b'hi there') self.sock.connectionLost('fake reason') @@ -138,13 +143,13 @@ def test_socks4aSuccessfulResolution(self): clientRequest = ( struct.pack('!BBH', 4, 1, 34) + socket.inet_aton('0.0.0.1') - + 'fooBAZ\0' - + 'localhost\0') + + b'fooBAZ\0' + + b'localhost\0') # Deliver the bytes one by one to exercise the protocol's buffering # logic. FakeResolverReactor's resolve method is invoked to "resolve" # the hostname. - for byte in clientRequest: + for byte in iterbytes(clientRequest): self.sock.dataReceived(byte) sent = self.sock.transport.value() @@ -160,14 +165,14 @@ def test_socks4aSuccessfulResolution(self): # Pass some data through and verify it is forwarded to the outgoing # connection. - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual( - self.sock.driver_outgoing.transport.value(), 'hello, world') + self.sock.driver_outgoing.transport.value(), b'hello, world') # Deliver some data from the output connection and verify it is # passed along to the incoming side. - self.sock.driver_outgoing.dataReceived('hi there') - self.assertEqual(self.sock.transport.value(), 'hi there') + self.sock.driver_outgoing.dataReceived(b'hi there') + self.assertEqual(self.sock.transport.value(), b'hi there') self.sock.connectionLost('fake reason') @@ -181,13 +186,13 @@ def test_socks4aFailedResolution(self): clientRequest = ( struct.pack('!BBH', 4, 1, 34) + socket.inet_aton('0.0.0.1') - + 'fooBAZ\0' - + 'failinghost\0') + + b'fooBAZ\0' + + b'failinghost\0') # Deliver the bytes one by one to exercise the protocol's buffering # logic. FakeResolverReactor's resolve method is invoked to "resolve" # the hostname. - for byte in clientRequest: + for byte in iterbytes(clientRequest): self.sock.dataReceived(byte) # Verify that the server responds with a 91 error. @@ -206,8 +211,8 @@ def test_accessDenied(self): self.sock.dataReceived( struct.pack('!BBH', 4, 1, 4242) + socket.inet_aton('10.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') self.assertEqual(self.sock.transport.value(), struct.pack('!BBH', 0, 91, 0) + socket.inet_aton('0.0.0.0')) @@ -219,14 +224,14 @@ def test_eofRemote(self): self.sock.dataReceived( struct.pack('!BBH', 4, 1, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') self.sock.transport.clear() # pass some data through - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual(self.sock.driver_outgoing.transport.value(), - 'hello, world') + b'hello, world') # now close it from the server side self.sock.driver_outgoing.transport.loseConnection() @@ -237,14 +242,14 @@ def test_eofLocal(self): self.sock.dataReceived( struct.pack('!BBH', 4, 1, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') self.sock.transport.clear() # pass some data through - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual(self.sock.driver_outgoing.transport.value(), - 'hello, world') + b'hello, world') # now close it from the client side self.sock.connectionLost('fake reason') @@ -259,7 +264,7 @@ def setUp(self): self.sock = SOCKSv4Driver() self.sock.transport = StringTCPTransport() self.sock.connectionMade() - self.sock.reactor = FakeResolverReactor({"localhost":"127.0.0.1"}) + self.sock.reactor = FakeResolverReactor({b"localhost":"127.0.0.1"}) ## def tearDown(self): ## # TODO ensure the listen port is closed @@ -272,8 +277,8 @@ def test_simple(self): self.sock.dataReceived( struct.pack('!BBH', 4, 2, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') sent = self.sock.transport.value() self.sock.transport.clear() self.assertEqual(sent, @@ -297,13 +302,13 @@ def test_simple(self): self.assertFalse(self.sock.transport.stringTCPTransport_closing) # pass some data through - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual(incoming.transport.value(), - 'hello, world') + b'hello, world') # the other way around - incoming.dataReceived('hi there') - self.assertEqual(self.sock.transport.value(), 'hi there') + incoming.dataReceived(b'hi there') + self.assertEqual(self.sock.transport.value(), b'hi there') self.sock.connectionLost('fake reason') @@ -321,13 +326,13 @@ def test_socks4a(self): clientRequest = ( struct.pack('!BBH', 4, 2, 34) + socket.inet_aton('0.0.0.1') - + 'fooBAZ\0' - + 'localhost\0') + + b'fooBAZ\0' + + b'localhost\0') # Deliver the bytes one by one to exercise the protocol's buffering # logic. FakeResolverReactor's resolve method is invoked to "resolve" # the hostname. - for byte in clientRequest: + for byte in iterbytes(clientRequest): self.sock.dataReceived(byte) sent = self.sock.transport.value() @@ -358,12 +363,12 @@ def test_socks4a(self): # Deliver some data from the output connection and verify it is # passed along to the incoming side. - self.sock.dataReceived('hi there') - self.assertEqual(incoming.transport.value(), 'hi there') + self.sock.dataReceived(b'hi there') + self.assertEqual(incoming.transport.value(), b'hi there') # the other way around - incoming.dataReceived('hi there') - self.assertEqual(self.sock.transport.value(), 'hi there') + incoming.dataReceived(b'hi there') + self.assertEqual(self.sock.transport.value(), b'hi there') self.sock.connectionLost('fake reason') @@ -377,13 +382,13 @@ def test_socks4aFailedResolution(self): clientRequest = ( struct.pack('!BBH', 4, 2, 34) + socket.inet_aton('0.0.0.1') - + 'fooBAZ\0' - + 'failinghost\0') + + b'fooBAZ\0' + + b'failinghost\0') # Deliver the bytes one by one to exercise the protocol's buffering # logic. FakeResolverReactor's resolve method is invoked to "resolve" # the hostname. - for byte in clientRequest: + for byte in iterbytes(clientRequest): self.sock.dataReceived(byte) # Verify that the server responds with a 91 error. @@ -402,20 +407,21 @@ def test_accessDenied(self): self.sock.dataReceived( struct.pack('!BBH', 4, 2, 4242) + socket.inet_aton('10.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') self.assertEqual(self.sock.transport.value(), struct.pack('!BBH', 0, 91, 0) + socket.inet_aton('0.0.0.0')) self.assertTrue(self.sock.transport.stringTCPTransport_closing) self.assertIsNone(self.sock.driver_listen) + def test_eofRemote(self): self.sock.dataReceived( struct.pack('!BBH', 4, 2, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') sent = self.sock.transport.value() self.sock.transport.clear() @@ -434,20 +440,21 @@ def test_eofRemote(self): self.assertFalse(self.sock.transport.stringTCPTransport_closing) # pass some data through - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual(incoming.transport.value(), - 'hello, world') + b'hello, world') # now close it from the server side incoming.transport.loseConnection() incoming.connectionLost('fake reason') + def test_eofLocal(self): self.sock.dataReceived( struct.pack('!BBH', 4, 2, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') sent = self.sock.transport.value() self.sock.transport.clear() @@ -466,19 +473,20 @@ def test_eofLocal(self): self.assertFalse(self.sock.transport.stringTCPTransport_closing) # pass some data through - self.sock.dataReceived('hello, world') + self.sock.dataReceived(b'hello, world') self.assertEqual(incoming.transport.value(), - 'hello, world') + b'hello, world') # now close it from the client side self.sock.connectionLost('fake reason') + def test_badSource(self): self.sock.dataReceived( struct.pack('!BBH', 4, 2, 34) + socket.inet_aton('1.2.3.4') - + 'fooBAR' - + '\0') + + b'fooBAR' + + b'\0') sent = self.sock.transport.value() self.sock.transport.clear() diff --git a/twisted/topfiles/8665.feature b/twisted/topfiles/8665.feature new file mode 100644 index 00000000000..ae33200b114 --- /dev/null +++ b/twisted/topfiles/8665.feature @@ -0,0 +1 @@ +twisted.protocols.socks has been ported to Python 3 diff --git a/twisted/topfiles/8693.bugfix b/twisted/topfiles/8693.bugfix new file mode 100644 index 00000000000..7db368722c5 --- /dev/null +++ b/twisted/topfiles/8693.bugfix @@ -0,0 +1 @@ +Reduced buffersize to 2^14 bytes in _write method of twisted.protocols.tls so as to avoid exceeding the maximum TLS payload.