Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions twisted/protocols/socks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,62 @@


class SOCKSv4Outgoing(protocol.Protocol):

def __init__(self,socks):
def __init__(self, socks):
self.socks=socks


def connectionMade(self):
peer = self.transport.getPeer()
self.socks.makeReply(90, 0, port=peer.port, ip=peer.host)
self.socks.otherConn=self


def connectionLost(self, reason):
self.socks.transport.loseConnection()

def dataReceived(self,data):

def dataReceived(self, data):
self.socks.write(data)


def write(self,data):
self.socks.log(self,data)
self.transport.write(data)



class SOCKSv4Incoming(protocol.Protocol):

def __init__(self,socks):
self.socks=socks
self.socks.otherConn=self


def connectionLost(self, reason):
self.socks.transport.loseConnection()


def dataReceived(self,data):
self.socks.write(data)

def write(self,data):

def write(self, data):
self.socks.log(self,data)
self.transport.write(data)



class SOCKSv4(protocol.Protocol):
"""
An implementation of the SOCKSv4 protocol.

@type logging: C{str} or L{None}
@type logging: L{str} or L{None}
@ivar logging: If not L{None}, the name of the logfile to which connection
information will be written.

@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
@ivar reactor: The reactor used to create connections.

@type buf: C{str}
@type buf: L{str}
@ivar buf: Part of a SOCKSv4 connection request.

@type otherConn: C{SOCKSv4Incoming}, C{SOCKSv4Outgoing} or L{None}
Expand All @@ -79,35 +85,37 @@ def __init__(self, logging=None, reactor=reactor):
self.logging = logging
self.reactor = reactor


def connectionMade(self):
self.buf = ""
self.buf = b""
self.otherConn = None


def dataReceived(self, data):
"""
Called whenever data is received.

@type data: C{str}
@type data: L{bytes}
@param data: Part or all of a SOCKSv4 packet.
"""
if self.otherConn:
self.otherConn.write(data)
return
self.buf = self.buf + data
completeBuffer = self.buf
if "\000" in self.buf[8:]:
if b"\000" in self.buf[8:]:
head, self.buf = self.buf[:8], self.buf[8:]
version, code, port = struct.unpack("!BBH", head[:4])
user, self.buf = self.buf.split("\000", 1)
if head[4:7] == "\000\000\000" and head[7] != "\000":
user, self.buf = self.buf.split(b"\000", 1)
if head[4:7] == b"\000\000\000" and head[7:8] != b"\000":
# An IP address of the form 0.0.0.X, where X is non-zero,
# signifies that this is a SOCKSv4a packet.
# If the complete packet hasn't been received, restore the
# buffer and wait for it.
if "\000" not in self.buf:
if b"\000" not in self.buf:
self.buf = completeBuffer
return
server, self.buf = self.buf.split("\000", 1)
server, self.buf = self.buf.split(b"\000", 1)
d = self.reactor.resolve(server)
d.addCallback(self._dataReceived2, user,
version, code, port)
Expand All @@ -118,27 +126,28 @@ def dataReceived(self, data):

self._dataReceived2(server, user, version, code, port)


def _dataReceived2(self, server, user, version, code, port):
"""
The second half of the SOCKS connection setup. For a SOCKSv4 packet this
is after the server address has been extracted from the header. For a
SOCKSv4a packet this is after the host name has been resolved.

@type server: C{str}
@type server: L{str}
@param server: The IP address of the destination, represented as a
dotted quad.

@type user: C{str}
@type user: L{str}
@param user: The username associated with the connection.

@type version: C{int}
@type version: L{int}
@param version: The SOCKS protocol version number.

@type code: C{int}
@type code: L{int}
@param code: The comand code. 1 means establish a TCP/IP stream
connection, and 2 means establish a TCP/IP port binding.

@type port: C{int}
@type port: L{int}
@param port: The port number associated with the connection.
"""
assert version == 4, "Bad version code: %s" % version
Expand All @@ -154,32 +163,39 @@ def _dataReceived2(self, server, user, version, code, port):
self = self: self.makeReply(90, 0, x[1], x[0]))
else:
raise RuntimeError("Bad Connect Code: %s" % (code,))
assert self.buf == "", "hmm, still stuff in buffer... %s" % repr(
assert self.buf == b"", "hmm, still stuff in buffer... %s" % repr(
self.buf)


def connectionLost(self, reason):
if self.otherConn:
self.otherConn.transport.loseConnection()


def authorize(self,code,server,port,user):
log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
return 1


def connectClass(self, host, port, klass, *args):
return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)


def listenClass(self, port, klass, *args):
serv = reactor.listenTCP(port, klass(*args))
return defer.succeed(serv.getHost()[1:])


def makeReply(self,reply,version=0,port=0,ip="0.0.0.0"):
self.transport.write(struct.pack("!BBH",version,reply,port)+socket.inet_aton(ip))
if reply!=90: self.transport.loseConnection()


def write(self,data):
self.log(self,data)
self.transport.write(data)


def log(self,proto,data):
if not self.logging: return
peer = self.transport.getPeer()
Expand Down Expand Up @@ -208,10 +224,10 @@ class SOCKSv4Factory(protocol.Factory):

Constructor accepts one argument, a log file name.
"""

def __init__(self, log):
self.logging = log


def buildProtocol(self, addr):
return SOCKSv4(self.logging, reactor)

Expand All @@ -221,7 +237,6 @@ class SOCKSv4IncomingFactory(protocol.Factory):
"""
A utility class for building protocols for incoming connections.
"""

def __init__(self, socks, ip):
self.socks = socks
self.ip = ip
Expand Down
23 changes: 16 additions & 7 deletions twisted/protocols/test/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Otherwise, the pyOpenSSL dependency must be satisfied, so all these
# imports will work.
from OpenSSL.crypto import X509Type
from OpenSSL.SSL import (TLSv1_METHOD, Error, Context, ConnectionType,
from OpenSSL.SSL import (TLSv1_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD, Error, Context, ConnectionType,
WantReadError)
from twisted.internet.ssl import PrivateCertificate, optionsForClientTLS
from twisted.test.ssl_helpers import (ClientTLSContext, ServerTLSContext,
Expand Down Expand Up @@ -62,8 +62,9 @@ class HandshakeCallbackContextFactory:
# https://bugs.launchpad.net/pyopenssl/+bug/372832
SSL_CB_HANDSHAKE_DONE = 0x20

def __init__(self):
def __init__(self, method=TLSv1_METHOD):
self._finished = Deferred()
self._method = method


def factoryAndDeferred(cls):
Expand Down Expand Up @@ -93,7 +94,7 @@ def getContext(self):
Create and return an SSL context configured to use L{self._info} as the
info callback.
"""
context = Context(TLSv1_METHOD)
context = Context(self._method)
context.set_info_callback(self._info)
return context

Expand Down Expand Up @@ -704,22 +705,22 @@ def cbConnectionDone(ignored):
return connectionDeferred


def test_hugeWrite(self):
def hugeWrite(self, method=TLSv1_METHOD):
"""
If a very long string is passed to L{TLSMemoryBIOProtocol.write}, any
trailing part of it which cannot be send immediately is buffered and
sent later.
"""
bytes = b"some bytes"
factor = 8192
factor = 2 ** 20
class SimpleSendingProtocol(Protocol):
def connectionMade(self):
self.transport.write(bytes * factor)

clientFactory = ClientFactory()
clientFactory.protocol = SimpleSendingProtocol

clientContextFactory = HandshakeCallbackContextFactory()
clientContextFactory = HandshakeCallbackContextFactory(method=method)
wrapperFactory = TLSMemoryBIOFactory(
clientContextFactory, True, clientFactory)
sslClientProtocol = wrapperFactory.buildProtocol(None)
Expand All @@ -728,7 +729,7 @@ def connectionMade(self):
serverFactory = ServerFactory()
serverFactory.protocol = lambda: serverProtocol

serverContextFactory = ServerTLSContext()
serverContextFactory = ServerTLSContext(method=method)
wrapperFactory = TLSMemoryBIOFactory(
serverContextFactory, False, serverFactory)
sslServerProtocol = wrapperFactory.buildProtocol(None)
Expand All @@ -742,6 +743,14 @@ def cbConnectionDone(ignored):
connectionDeferred.addCallback(cbConnectionDone)
return connectionDeferred

def test_hugeWrite_TLSv1(self):
return self.hugeWrite()

def test_hugeWrite_TLSv1_1(self):
return self.hugeWrite(method=TLSv1_1_METHOD)

def test_hugeWrite_TLSv1_2(self):
return self.hugeWrite(method=TLSv1_2_METHOD)

def test_disorderlyShutdown(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion twisted/protocols/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def _write(self, bytes):
return

# A TLS payload is 16kB max
bufferSize = 2 ** 16
bufferSize = 2 ** 14

# How far into the input we've gotten so far
alreadySent = 0
Expand Down
2 changes: 2 additions & 0 deletions twisted/python/dist3.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@
"twisted.protocols.haproxy.test.test_v2parser",
"twisted.protocols.haproxy.test.test_wrapper",
"twisted.protocols.haproxy.test.test_parser",
"twisted.protocols.socks",
"twisted.protocols.tls",
"twisted.python.__init__",
"twisted.python._appdirs",
Expand Down Expand Up @@ -494,6 +495,7 @@
"twisted.test.test_reflect",
"twisted.test.test_roots",
"twisted.test.test_sob",
"twisted.test.test_socks",
"twisted.test.test_ssl",
"twisted.test.test_sslverify",
"twisted.test.test_stdio",
Expand Down
5 changes: 3 additions & 2 deletions twisted/test/ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def getContext(self):
class ServerTLSContext:
isClient = 0

def __init__(self, filename=certPath):
def __init__(self, filename=certPath, method=SSL.TLSv1_METHOD):
self.filename = filename
self._method = method

def getContext(self):
ctx = SSL.Context(SSL.TLSv1_METHOD)
ctx = SSL.Context(self._method)
ctx.use_certificate_file(self.filename)
ctx.use_privatekey_file(self.filename)
return ctx
Loading