diff --git a/README.rst b/README.rst index 423bd98..b7e345f 100644 --- a/README.rst +++ b/README.rst @@ -88,6 +88,7 @@ You can use the following options: - ``port``: the numeric port number to use (default to ``80``) - ``address``: the IP-address to bind to (default to ``''``) +- ``servername``: server header field used in responses (default to ``None``) Example configuration (in YAML):: @@ -106,7 +107,11 @@ of stores, ``memory`` and ``redis``. Each of these stores has specific options. - ``key_prefix``: a string prepended to a channel identifier to make a redis key. Use this to avoid key collision when you're using your redis server for other stuff. -Memory stores haven't any specific options (yet). +For memory stores: + +- ``min_messages``: the minimum number of messages to store per channel +- ``max_messages``: the maximum number of messages to store per channel +- ``message_timeout``: the length of time a message may be queued before it is expired Here is an example of how to specify the store (YAML):: @@ -161,6 +166,10 @@ A location has a ``type`` of either ``publisher`` or ``subscriber``. It supports - ``url``: the complete URL pattern to use for this location, eg: ``/channel/(\d+)/publish/``. Not you should have only one capture group, that must represent the channel id. This settings has precedence over ``prefix`` (not set by default) - ``polling`` (subscriber only): ``interval`` or ``long``, see the protocol_ for more information (default to ``long``) - ``create_on_post`` (publisher only): if set to ``false``, you will need to create a channel with a PUT request first before POSTing any data to it (default to ``true``) +- ``create_on_get`` (subscriber only): if set to ``true``, a non-existing channel will be automatically created at the first GET request (default to ``false``) +- ``allow_origin`` (subscriber only): value of ``Access-Control-Allow-Origin`` header send as defined by Cross-Origin Resource Sharing specification (default to ``*``) +- ``allow_credentials`` (subscriber only): value of ``Access-Control-Allow-Credentials`` header send as defined by Cross-Origin Resource Sharing specification (default to ``False``); cannot be ``True`` if ``allow_origin`` is set to ``*`` +- ``passthrough`` (subscriber only): if set to an URL, client's request headers will be passthrough to the given URL every time client subscribes or unsubscribes (default to ``None``) For info, the default configuration looks like this:: @@ -233,21 +242,13 @@ Caveats Running Tests ------------- -Make sure you have a test redis server accessible at ``localhost:6379``. **Be careful**, the tests suite will -flush your server default database, you've been warned. +Make sure you have a test redis server accessible at ``localhost:6379``. **Be careful, the tests suite will +flush your server default database, you've been warned.** Run the test suite with :: $ python setup.py nosetests -Known Issues ------------- - -- hbpushd depends on the development version of facebook's tornado. ``setup.py`` will install a - compatible version, but if you have already installed tornado through ``easy_install`` or ``pip``, - you might have some problems with Etags, or when launching hbpushd. In that case, reinstall - the latest version of tornado_. - Change log ---------- diff --git a/bin/hbpushd b/bin/hbpushd index 3cf9420..3acc48c 100755 --- a/bin/hbpushd +++ b/bin/hbpushd @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/env python ## DEFAULT CONFIGURATION ## default_store= { @@ -9,13 +9,20 @@ default_store= { 'database': 0, }, 'memory': { + 'min_messages': 0, + 'max_messages': 0, + 'message_timeout': 0, } } default_location = { 'subscriber': { 'polling': 'long', + 'create_on_get': False, 'store': 'default', + 'allow_origin': '*', + 'allow_credentials': False, + 'passthrough': None, }, 'publisher': { 'create_on_post': True, @@ -26,6 +33,7 @@ default_location = { defaults = { 'port': 80, 'address': '', + 'servername': None, 'store': { 'type': 'memory', }, @@ -100,7 +108,10 @@ def make_stores(stores_dict): from hbpush.pubsub.publisher import Publisher from hbpush.pubsub.subscriber import Subscriber, LongPollingSubscriber -def make_location(loc_dict, stores={}): +def make_location(loc_dict, stores=None, servername=None): + if stores is None: + stores = {} + loc_conf = default_location.get(loc_dict['type'], {}).copy() loc_conf.update(loc_dict) @@ -118,16 +129,19 @@ def make_location(loc_dict, stores={}): else: raise InvalidConfigurationError('Invalid location type `%s`' % loc_type) - url = loc_conf.pop('url', loc_conf.pop('prefix')+'(.+)') + url = loc_conf.pop('url', loc_conf.pop('prefix', '')+'(.+)') store_id = loc_conf.pop('store') - kwargs = {'registry': stores[store_id]['registry']} + kwargs = { + 'registry': stores[store_id]['registry'], + 'servername': servername, + } kwargs.update(loc_conf) return (url, cls, kwargs) from functools import partial conf['store'] = make_stores(conf['store']) -conf['locations'] = map(partial(make_location, stores=conf['store']), conf['locations']) +conf['locations'] = map(partial(make_location, stores=conf['store'], servername=conf['servername']), conf['locations']) from tornado.web import Application from tornado.httpserver import HTTPServer diff --git a/hbpush/__init__.py b/hbpush/__init__.py index 3a83701..842ec28 100644 --- a/hbpush/__init__.py +++ b/hbpush/__init__.py @@ -1,2 +1,2 @@ -VERSION = (0, 1, 0) +VERSION = (0, 1, 4) __version__ = '.'.join(map(str, VERSION)) diff --git a/hbpush/channel.py b/hbpush/channel.py index aecf6d9..a713345 100644 --- a/hbpush/channel.py +++ b/hbpush/channel.py @@ -1,6 +1,9 @@ +from tornado import httpclient from hbpush.message import Message + import logging import time +import urllib class Channel(object): @@ -23,8 +26,7 @@ def __init__(self, id, store): # Empty message, we just want to keep etag and lastmodified data self.last_message = Message(0, -1) - def get_last_message(self): - return self.last_message + self.client = httpclient.AsyncHTTPClient() def send_to_subscribers(self, message): # We work on a copy to deal with reentering subscribers @@ -49,30 +51,47 @@ def _process_message(message): message = self.make_message(content_type, body) self.store.post(self.id, message, callback=_process_message, errback=errback) - def wait_for(self, last_modified, etag, id_subscriber, callback, errback): + def wait_for(self, last_modified, etag, request, passthrough, id_subscriber, callback, errback): request_msg = Message(last_modified, etag) def _cb(message): if request_msg >= message: - self.subscribe(id_subscriber, _cb, errback) + self.subscribe(id_subscriber, request, passthrough, _cb, errback) else: callback(message) - self.subscribe(id_subscriber, _cb, errback) + self.subscribe(id_subscriber, request, passthrough, _cb, errback) + + def _passthrough(self, action, request, passthrough): + if not passthrough or request.method != 'GET': + return + + def ignore(response): + pass - def subscribe(self, id_subscriber, callback, errback): + url = passthrough + body = urllib.urlencode({'channel_id': self.id, action: 1}) + self.client.fetch(url, ignore, method='POST', body=body, headers=request.headers) + + def subscribe(self, id_subscriber, request, passthrough, callback, errback): + self._passthrough('subscribe', request, passthrough) self.subscribers[id_subscriber] = (callback, errback) - def unsubscribe(self, id_subscriber): + def unsubscribe(self, id_subscriber, request, passthrough): + self._passthrough('unsubscribe', request, passthrough) self.subscribers.pop(id_subscriber, None) def get(self, last_modified, etag, callback, errback): request_msg = Message(last_modified, etag) if request_msg < self.last_message: - self.store.get(self.id, last_modified, etag, callback=callback, errback=errback) - else: - errback(Channel.NotModified()) + try: + self.store.get(self.id, last_modified, etag, callback=callback, errback=errback) + return + except Message.Expired: + pass + + errback(Channel.NotModified()) def delete(self, callback, errback): for id, (cb, eb) in self.subscribers.items(): @@ -85,7 +104,7 @@ def delete(self, callback, errback): def make_message(self, content_type, body): if not self.sentinel: - self.sentinel = self.get_last_message() + self.sentinel = self.last_message last_modified = int(time.time()) if last_modified == self.sentinel.last_modified: diff --git a/hbpush/message.py b/hbpush/message.py index 6011c00..03047e0 100644 --- a/hbpush/message.py +++ b/hbpush/message.py @@ -16,3 +16,6 @@ class DoesNotExist(Exception): class Invalid(Exception): pass + + class Expired(Exception): + pass diff --git a/hbpush/pubsub/__init__.py b/hbpush/pubsub/__init__.py index e6d6261..a2cce38 100644 --- a/hbpush/pubsub/__init__.py +++ b/hbpush/pubsub/__init__.py @@ -14,11 +14,27 @@ class PubSubHandler(RequestHandler): def __init__(self, *args, **kwargs): self.registry = kwargs.pop('registry', None) + self.servername = kwargs.pop('servername', None) + self.allow_origin = kwargs.pop('allow_origin', '*') + self.allow_credentials = kwargs.pop('allow_credentials', False) + if (self.allow_origin == '*' and self.allow_credentials): + raise AttributeError("allow_origin cannot be '*' with allow_credentials set to true") super(PubSubHandler, self).__init__(*args, **kwargs) def add_vary_header(self): self.set_header('Vary', 'If-Modified-Since, If-None-Match') + def add_accesscontrol_headers(self): + self.set_header('Access-Control-Allow-Origin', self.allow_origin) + self.set_header('Access-Control-Allow-Headers', 'If-Modified-Since, If-None-Match, X-Cookie') + self.set_header('Access-Control-Expose-Headers', 'Last-Modified, Etag, Cache-Control') + self.set_header('Access-Control-Allow-Credentials', 'true' if self.allow_credentials else 'false') + self.set_header('Access-Control-Max-Age', '864000') + + def set_default_headers(self): + if self.servername: + self.set_header('Server', self.servername) + def _handle_request_exception(self, e): if e.__class__ in self.exception_mapping: e = HTTPError(self.exception_mapping[e.__class__], str(e)) @@ -27,7 +43,6 @@ def _handle_request_exception(self, e): errback = _handle_request_exception - def simple_finish(self, *args, **kwargs): # ignore everything, and just finish the request self.finish() diff --git a/hbpush/pubsub/subscriber.py b/hbpush/pubsub/subscriber.py index 6bd0e25..9d2319c 100644 --- a/hbpush/pubsub/subscriber.py +++ b/hbpush/pubsub/subscriber.py @@ -4,24 +4,50 @@ from email.utils import formatdate, parsedate_tz, mktime_tz from functools import partial +import logging +import calendar + +# mktime_tz has some problems on Windows (http://bugs.python.org/issue14653), +# so we are converting manually +def convert_timestamp(timestamp): + t = parsedate_tz(timestamp) + if t[9] is None: + return mktime_tz(t) + else: + g = calendar.timegm(t[:9]) + return g - t[9] class Subscriber(PubSubHandler): + def __init__(self, *args, **kwargs): + self.create_on_get = kwargs.pop('create_on_get', False) + self.passthrough = kwargs.pop('passthrough', None) + super(Subscriber, self).__init__(*args, **kwargs) + @asynchronous def get(self, channel_id): try: etag = int(self.request.headers.get('If-None-Match', -1)) - last_modified = int('If-Modified-Since' in self.request.headers and mktime_tz(parsedate_tz(self.request.headers['If-Modified-Since'])) or 0) - except: + last_modified = int('If-Modified-Since' in self.request.headers and convert_timestamp(self.request.headers['If-Modified-Since']) or 0) + except Exception, e: + logging.warning('Error parsing request headers: %s', e) raise HTTPError(400) - self.registry.get(channel_id, + getattr(self.registry, 'get_or_create' if self.create_on_get else 'get')(channel_id, callback=self.async_callback(partial(self._process_channel, last_modified, etag)), errback=self.errback) + def options(self, channel_id): + self.add_accesscontrol_headers() + def _process_message(self, message): self.set_header('Etag', message.etag) + # Chrome and other WebKit-based browsers do not (yet) support Access-Control-Expose-Headers, + # but they allow access to Cache-Control so we use it to additionally store etag information there + # (This field is by standard extendable with custom tokens) + self.set_header('Cache-Control', '%s=%s' % ('etag', message.etag)) self.set_header('Last-Modified', formatdate(message.last_modified, localtime=False, usegmt=True)) self.add_vary_header() + self.add_accesscontrol_headers() self.set_header('Content-Type', message.content_type) self.write(message.body) self.finish() @@ -35,7 +61,7 @@ def _process_channel(self, last_modified, etag, channel): class LongPollingSubscriber(Subscriber): def unsubscribe(self): if hasattr(self, 'channel'): - self.channel.unsubscribe(id(self)) + self.channel.unsubscribe(id(self), self.request, self.passthrough) on_connection_close = unsubscribe def finish(self, chunk=None): @@ -46,7 +72,7 @@ def _process_channel(self, last_modified, etag, channel): @self.async_callback def _wait_for_message(error): if error.__class__ == Channel.NotModified: - self.channel.wait_for(last_modified, etag, id(self), callback=self.async_callback(self._process_message), errback=self.errback) + self.channel.wait_for(last_modified, etag, self.request, self.passthrough, id(self), callback=self.async_callback(self._process_message), errback=self.errback) else: self.errback(error) diff --git a/hbpush/store/__init__.py b/hbpush/store/__init__.py index cb63129..2c63f8f 100644 --- a/hbpush/store/__init__.py +++ b/hbpush/store/__init__.py @@ -1,4 +1,7 @@ class Store(object): + def __init__(self, *args, **kwargs): + pass + def get(self, channel_id, last_modified, etag, callback, errback): raise NotImplementedError("") diff --git a/hbpush/store/memory.py b/hbpush/store/memory.py index d0c301b..3a694ec 100644 --- a/hbpush/store/memory.py +++ b/hbpush/store/memory.py @@ -2,16 +2,47 @@ from hbpush.message import Message from bisect import bisect +import time class MemoryStore(Store): def __init__(self, *args, **kwargs): super(MemoryStore, self).__init__(*args, **kwargs) + self.min_messages = kwargs.pop('min_messages', 0) + self.max_messages = kwargs.pop('max_messages', 0) + self.message_timeout = kwargs.pop('message_timeout', 0) self.messages = {} + self.expired_channels = {} - def get(self, channel_id, last_modified, etag, callback, errback): + def _expire_messages(self, channel_id): channel_messages = self.messages.setdefault(channel_id, []) + if not channel_messages: + if self.expired_channels.get(channel_id, False): + raise Message.Expired() + else: + return channel_messages + + if self.max_messages and len(channel_messages) > self.max_messages: + channel_messages = channel_messages[-self.max_messages:] + + if self.message_timeout: + while channel_messages and len(channel_messages) > self.min_messages: + if channel_messages[0].last_modified + self.message_timeout >= int(time.time()): + break + channel_messages = channel_messages[1:] + + self.messages[channel_id] = channel_messages + + if not self.messages[channel_id]: + self.expired_channels[channel_id] = True + raise Message.Expired() + else: + return channel_messages + + def get(self, channel_id, last_modified, etag, callback, errback): + channel_messages = self._expire_messages(channel_id) + msg = Message(last_modified, etag) try: callback(channel_messages[bisect(channel_messages, msg)]) @@ -19,21 +50,25 @@ def get(self, channel_id, last_modified, etag, callback, errback): errback(Message.DoesNotExist()) def get_last(self, channel_id, callback, errback): - channel_messages = self.messages.setdefault(channel_id, []) + channel_messages = self._expire_messages(channel_id) - if len(channel_messages): + if channel_messages: callback(channel_messages[-1]) else: errback(Message.DoesNotExist()) def post(self, channel_id, message, callback, errback): self.messages.setdefault(channel_id, []).append(message) + self.expired_channels[channel_id] = False callback(message) def flush(self, channel_id, callback, errback): del self.messages[channel_id] + if channel_id in self.expired_channels: + del self.expired_channels[channel_id] callback(True) def flushall(self, callback, errback): self.messages = {} + self.expired_channels = {} callback(True) diff --git a/hbpush/store/redis.py b/hbpush/store/redis.py index 7eba1f9..c968694 100644 --- a/hbpush/store/redis.py +++ b/hbpush/store/redis.py @@ -9,7 +9,8 @@ class RedisStore(Store): - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): + super(RedisStore, self).__init__(*args, **kwargs) self.key_prefix = kwargs.pop('key_prefix', '') self.client = Client(**kwargs) self.client.connect() diff --git a/setup.py b/setup.py index ae306a7..7aa3b85 100644 --- a/setup.py +++ b/setup.py @@ -27,9 +27,8 @@ packages=('hbpush', 'hbpush.store', 'hbpush.utils', 'hbpush.pubsub'), scripts=('bin/hbpushd',), - dependency_links= ('http://github.com/facebook/tornado/tarball/b8271f94434208646eeec9cf33da703d97c5364e#egg=tornado-0.2', - 'http://github.com/clement/brukva/tarball/bff451511a3cc09cd52bebcf6372a59d36567827#egg=brukva-0.0.1',), + dependency_links= ('http://github.com/clement/brukva/tarball/bff451511a3cc09cd52bebcf6372a59d36567827#egg=brukva-0.0.1',), setup_requires=('nose>=0.11',), - install_requires=('PyYAML', 'brukva==0.0.1', 'tornado==0.2'), - requires=('PyYAML', 'brukva(==0.0.1)', 'tonardo(==0.2)'), + install_requires=('PyYAML', 'brukva>=0.0.1', 'tornado>0.2'), + requires=('PyYAML', 'brukva(>=0.0.1)', 'tonardo(>0.2)'), ) diff --git a/tests/mocks.py b/tests/mocks.py index 6da8acf..89bb896 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,6 +1,10 @@ from hbpush.pubsub.publisher import Publisher from hbpush.pubsub.subscriber import Subscriber, LongPollingSubscriber -from tornado.httpserver import HTTPHeaders, HTTPRequest +from tornado.httpserver import HTTPRequest +try: + from tornado.httpserver import HTTPHeaders +except ImportError: + from tornado.httputil import HTTPHeaders from tornado.web import HTTPError from tornado.ioloop import IOLoop