From 9404ed305cb0ec103b095fbe060aacf8e31ac7ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikl=C3=B3s=20Homolya?= Date: Thu, 12 Jul 2018 21:19:31 +0200 Subject: [PATCH 1/3] initial strip --- .binstar.yml | 13 --- .travis.yml | 6 -- LICENSE | 6 +- condarecipe/build.sh | 3 - condarecipe/meta.yaml | 32 ------ doc/Makefile | 157 --------------------------- doc/source/coffee.rst | 67 ------------ doc/source/conf.py | 245 ------------------------------------------ doc/source/index.rst | 22 ---- requirements.txt | 3 - setup.cfg | 8 +- setup.py | 5 +- 12 files changed, 6 insertions(+), 561 deletions(-) delete mode 100644 .binstar.yml delete mode 100755 condarecipe/build.sh delete mode 100644 condarecipe/meta.yaml delete mode 100644 doc/Makefile delete mode 100644 doc/source/coffee.rst delete mode 100644 doc/source/conf.py delete mode 100644 doc/source/index.rst diff --git a/.binstar.yml b/.binstar.yml deleted file mode 100644 index cc8b820d..00000000 --- a/.binstar.yml +++ /dev/null @@ -1,13 +0,0 @@ -package: coffee -user: firedrakeproject - -platform: - - linux-64 - -engine: - - python=2.7 - -script: - - conda build -q condarecipe - -build_targets: conda diff --git a/.travis.yml b/.travis.yml index 9aa92939..12c3ed06 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,15 +5,9 @@ language: python python: - "3.5" -addons: - apt: - packages: - - glpk-utils - before_install: - pip install -r requirements.txt - pip install flake8 - - pip install flake8-future-import install: - python setup.py install diff --git a/LICENSE b/LICENSE index 71fd8745..d2896c52 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ -Copyright (c) 2012, Imperial College London and others. Please see the -AUTHORS file in the main source directory for a full list of copyright -holders. All rights reserved. +Copyright (c) 2012-2018, Imperial College London and others. +Please see the AUTHORS file in the main source directory for a full +list of copyright holders. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/condarecipe/build.sh b/condarecipe/build.sh deleted file mode 100755 index 8e25a145..00000000 --- a/condarecipe/build.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -$PYTHON setup.py install diff --git a/condarecipe/meta.yaml b/condarecipe/meta.yaml deleted file mode 100644 index 9087211a..00000000 --- a/condarecipe/meta.yaml +++ /dev/null @@ -1,32 +0,0 @@ -package: - name: coffee - version: {{ environ.get('GIT_DESCRIBE_TAG','') }} - -source: - path: .. - -build: - number: {{ environ.get('GIT_DESCRIBE_NUMBER', 0) }} - -requirements: - build: - - python - - networkx - - run: - - python - - numpy - - networkx - -test: - requires: - - pytest - - flake8 - commands: - - py.test {{ os.path.join(environ.get('SRC_DIR'), 'tests') }} -v - - flake8 {{ environ.get('SRC_DIR') }} - -about: - home: http://www.firedrakeproject.org - license: BSD 3-clause - summary: COFFEE - COmpiler For Fast Expression Evaluation diff --git a/doc/Makefile b/doc/Makefile deleted file mode 100644 index 7e27156b..00000000 --- a/doc/Makefile +++ /dev/null @@ -1,157 +0,0 @@ -# Makefile for Sphinx documentation -# - -# You can set these variables from the command line. -APIDOCOPTS = -f -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -BUILDDIR = build - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext - -help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - -apidoc: - sphinx-apidoc ../coffee -o source/ -T $(APIDOCOPTS) - -clean: - -rm -rf $(BUILDDIR)/* - -html: apidoc - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: apidoc - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: apidoc - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: apidoc - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: apidoc - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: apidoc - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: apidoc - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/COFFEE.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/COFFEE.qhc" - -devhelp: apidoc - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/COFFEE" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/COFFEE" - @echo "# devhelp" - -epub: apidoc - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: apidoc - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: apidoc - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: apidoc - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: apidoc - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: apidoc - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: apidoc - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: apidoc - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: apidoc - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: apidoc - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." - -doctest: apidoc - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." diff --git a/doc/source/coffee.rst b/doc/source/coffee.rst deleted file mode 100644 index 966e827e..00000000 --- a/doc/source/coffee.rst +++ /dev/null @@ -1,67 +0,0 @@ -coffee Package -============== - -:mod:`autotuner` Module ------------------------ - -.. automodule:: coffee.autotuner - :members: - :undoc-members: - :show-inheritance: - -:mod:`base` Module ------------------- - -.. automodule:: coffee.base - :members: - :undoc-members: - :show-inheritance: - -:mod:`linear_algebra` Module ----------------------------- - -.. automodule:: coffee.linear_algebra - :members: - :undoc-members: - :show-inheritance: - -:mod:`optimizer` Module ------------------------ - -.. automodule:: coffee.optimizer - :members: - :undoc-members: - :show-inheritance: - -:mod:`plan` Module ------------------- - -.. automodule:: coffee.plan - :members: - :undoc-members: - :show-inheritance: - -:mod:`utils` Module -------------------- - -.. automodule:: coffee.utils - :members: - :undoc-members: - :show-inheritance: - -:mod:`vectorizer` Module ------------------------- - -.. automodule:: coffee.vectorizer - :members: - :undoc-members: - :show-inheritance: - -:mod:`version` Module ---------------------- - -.. automodule:: coffee.version - :members: - :undoc-members: - :show-inheritance: - diff --git a/doc/source/conf.py b/doc/source/conf.py deleted file mode 100644 index 11ed384e..00000000 --- a/doc/source/conf.py +++ /dev/null @@ -1,245 +0,0 @@ -# -*- coding: utf-8 -*- -# -# COFFEE documentation build configuration file, created by -# sphinx-quickstart on Tue Sep 30 11:25:59 2014. -# -# This file is execfile()d with the current directory set to its containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import sys -import os - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../..')) - -# -- General configuration ----------------------------------------------------- - -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be extensions -# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc'] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix of source filenames. -source_suffix = '.rst' - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = u'COFFEE' -copyright = u'2014, Fabio Luporini' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -execfile("../../coffee/version.py") -version = '%d.%d' % __version_info__[0:2] # noqa: pulled from coffee/version.py -# The full version, including alpha/beta/rc tags. -release = __version__ # noqa: pulled from coffee/version.py - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -#language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -exclude_patterns = [] - -# The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# -- Options for HTML output --------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -html_theme = 'default' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -#html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -#html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Output file base name for HTML help builder. -htmlhelp_basename = 'COFFEEdoc' - - -# -- Options for LaTeX output -------------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - #'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - #'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - #'preamble': '', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, author, documentclass [howto/manual]). -latex_documents = [ - ('index', 'COFFEE.tex', u'COFFEE Documentation', - u'Fabio Luporini', 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output -------------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'coffee', u'COFFEE Documentation', - [u'Fabio Luporini'], 1) -] - -# If true, show URL addresses after external links. -#man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------------ - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ('index', 'COFFEE', u'COFFEE Documentation', - u'Fabio Luporini', 'COFFEE', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' diff --git a/doc/source/index.rst b/doc/source/index.rst deleted file mode 100644 index bc325b8e..00000000 --- a/doc/source/index.rst +++ /dev/null @@ -1,22 +0,0 @@ -.. COFFEE documentation master file, created by - sphinx-quickstart on Tue Sep 30 11:25:59 2014. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to COFFEE's documentation! -================================== - -Contents: - -.. toctree:: - :maxdepth: 2 - - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - diff --git a/requirements.txt b/requirements.txt index 24662ec5..24ce15ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1 @@ -PuLP -networkx numpy -six diff --git a/setup.cfg b/setup.cfg index 9478c194..8741a32e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,3 @@ [flake8] -ignore = - E501,F403,F405,E226,E265,E731,E402,E266,F999, - FI14,FI54, - FI50,FI51,FI53 -exclude = .git,,__pycache__,build,dist,doc/source/conf.py -min-version = 2.7 +ignore = E501,F403,F405,E226,E265,E731,E402,E266,F999 +exclude = .git,__pycache__,build,dist,doc/source/conf.py diff --git a/setup.py b/setup.py index 098a175d..0a42eddb 100644 --- a/setup.py +++ b/setup.py @@ -31,8 +31,6 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import absolute_import, print_function, division - try: from setuptools import setup except ImportError: @@ -45,11 +43,10 @@ author='Fabio Luporini', author_email='f.luporini12@imperial.ac.uk', url='https://github.com/coneoproject/COFFEE', - install_requires=["networkx"], classifiers=['Development Status :: 3 - Alpha', 'Intended Audience :: Developers', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', - 'Programming Language :: Python :: 2.7'], + 'Programming Language :: Python :: 3.5'], packages=['coffee', 'coffee.visitors']) From c3bf14724547b9f7e73f572a9ec9fd9b74ecf61b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikl=C3=B3s=20Homolya?= Date: Thu, 12 Jul 2018 21:43:07 +0200 Subject: [PATCH 2/3] strip dead code --- coffee/cse.py | 508 ------------------------- coffee/expander.py | 118 ------ coffee/factorizer.py | 236 ------------ coffee/hoister.py | 351 ------------------ coffee/optimizer.py | 512 +------------------------- coffee/rewriter.py | 639 -------------------------------- coffee/scheduler.py | 856 ------------------------------------------- coffee/version.py | 4 - 8 files changed, 1 insertion(+), 3223 deletions(-) delete mode 100644 coffee/cse.py delete mode 100644 coffee/expander.py delete mode 100644 coffee/factorizer.py delete mode 100644 coffee/hoister.py delete mode 100644 coffee/rewriter.py delete mode 100644 coffee/scheduler.py delete mode 100644 coffee/version.py diff --git a/coffee/cse.py b/coffee/cse.py deleted file mode 100644 index dd79c3f2..00000000 --- a/coffee/cse.py +++ /dev/null @@ -1,508 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six import iterkeys, iteritems, itervalues -from six.moves import zip - -import operator - -from .base import * -from .utils import * -from coffee.visitors import EstimateFlops -from .expression import MetaExpr -from .logger import log, COST_MODEL -from functools import reduce - - -class Temporary(object): - - """A Temporary stores useful information for a statement (e.g., an Assign - or an AugmentedAssig) that computes a temporary variable; that is, a variable - that is read in more than one place.""" - - def __init__(self, node, main_loop, nest, linear_reads_costs=None): - self.level = -1 - self.pushed = False - self.readby = [] - - self.node = node - self.main_loop = main_loop - self.nest = nest - self.linear_reads_costs = linear_reads_costs or OrderedDict() - self.flops = EstimateFlops().visit(node) - - @property - def name(self): - return self.symbol.symbol if self.symbol else None - - @property - def rank(self): - return self.symbol.rank if self.symbol else None - - @property - def linearity_degree(self): - return len(self.main_linear_loops) - - @property - def symbol(self): - if isinstance(self.node, Writer): - return self.node.lvalue - elif isinstance(self.node, Symbol): - return self.node - else: - return None - - @property - def expr(self): - if isinstance(self.node, Writer): - return self.node.rvalue - else: - return None - - @property - def urepr(self): - return self.symbol.urepr - - @property - def reads(self): - return Find(Symbol).visit(self.expr)[Symbol] if self.expr else [] - - @property - def linear_reads(self): - return list(iterkeys(self.linear_reads_costs)) if self.linear_reads_costs else [] - - @property - def loops(self): - return list(zip(*self.nest))[0] - - @property - def main_linear_loops(self): - return [l for l in self.main_loops if l.is_linear] - - @property - def main_linear_nest(self): - return [(l, p) for l, p in self.main_nest if l in self.linear_loops] - - @property - def main_loops(self): - index = self.loops.index(self.main_loop) - return [l for l in self.loops[:index + 1]] - - @property - def main_nest(self): - return [(l, p) for l, p in self.nest if l in self.main_loops] - - @property - def flops_projection(self): - # #muls + #sums - nmuls = len(self.linear_reads) - return (nmuls) + (nmuls - 1) - - @property - def is_ssa(self): - return self.symbol not in self.readby - - @property - def is_static_init(self): - return isinstance(self.expr, ArrayInit) - - @property - def is_increment(self): - return isinstance(self.node, Incr) - - @property - def reductions(self): - return [l for l in self.main_loops if l.dim not in self.rank] - - @property - def nreductions(self): - return len(self.reductions) - - def niters(self, mode='all', handle=None): - assert mode in ['all', 'outer', 'nonlinear', 'in', 'out'] - handle = handle or [] - limit = self.loops.index(self.main_loop) - loops = self.loops[:limit + 1] - if mode == 'all': - sizes = [l.size for l in loops] - elif mode == 'outer': - sizes = [l.size for l in loops if l is not self.main_loop] - elif mode == 'nonlinear': - sizes = [l.size for l in loops if not l.is_linear] - elif mode == 'in': - sizes = [l.size for l in loops if l.dim in handle] - else: - sizes = [l.size for l in loops if l.dim not in handle] - return reduce(operator.mul, sizes, 1) - - def depends(self, others): - """Return True if ``self`` reads a temporary or is read by a temporary - that appears in the iterator ``others``, False otherwise.""" - dependencies = self.linear_reads + self.reads - for t in others: - if any(s.urepr == t.urepr for s in dependencies): - return True - return False - - def reconstruct(self): - temporary = Temporary(self.node, self.main_loop, self.nest, - OrderedDict(self.linear_reads_costs)) - temporary.level = self.level - temporary.readby = list(self.readby) - return temporary - - def __str__(self): - return "%s: level=%d, flops/iter=%d, linear_reads=[%s], isread=[%s]" % \ - (self.symbol, self.level, self.flops, - ", ".join([str(i) for i in self.linear_reads]), - ", ".join([str(i) for i in self.readby])) - - -class CSEUnpicker(object): - - """Analyze loops in which some temporary variables are computed and, applying - a cost model, decides whether to leave a temporary intact or inline it for - creating factorization and code motion opportunities. - - The cost model exploits one particular property of loops, namely linearity in - symbols (further information concerning loop linearity is available in the module - ``expression.py``).""" - - def __init__(self, exprs, header, hoisted): - self.exprs = exprs - self.header = header - self.hoisted = hoisted - - @property - def type(self): - return list(itervalues(self.exprs))[0].type - - @property - def linear_dims(self): - return list(itervalues(self.exprs))[0].linear_dims - - def _push_temporaries(self, temporaries, trace, global_trace, ra, decls): - - def is_pushable(temporary, temporaries): - # To be pushable ... - if not temporary.is_ssa: - # ... must be written only once - return False - if not temporary.readby: - # ... must actually be read by some other temporaries (the output - # variables are not) - return False - if temporary.is_static_init: - # ... its rvalue must not be an array initializer - return False - if temporary.depends(temporaries): - # ... it cannot depend on other temporaries in the same level - return False - pushed_in = [global_trace.get(rb.urepr) for rb in temporary.readby] - pushed_in = set(rb.main_loop.children[0] for rb in pushed_in if rb) - reads = [s for s in temporary.reads if not s.is_number] - for s in reads: - # ... all the read temporaries must be accessible in the loops in which - # they will be pushed - if s.urepr in global_trace and global_trace[s.urepr].pushed: - continue - if s.symbol not in decls: - continue - if any(l not in ra[decls[s.symbol]] for l in pushed_in): - return False - return True - - to_replace, modified_temporaries = {}, OrderedDict() - for t in temporaries: - # Track temporaries to be pushed from /level-1/ into the later /level/s - if not is_pushable(t, temporaries): - continue - to_replace[t.symbol] = t.expr or t.symbol - for rb in t.readby: - modified_temporaries[rb.urepr] = trace.get(rb.urepr, - global_trace[rb.urepr]) - # The temporary is going to be pushed, so we can remove it as long as - # it is not needed somewhere else - if t.node in t.main_loop.body and\ - all(rb.urepr in global_trace for rb in t.readby): - global_trace[t.urepr].pushed = True - t.main_loop.body.remove(t.node) - - # Transform the AST (note: node replacement must happen in the order - # in which the temporaries have been encountered) - modified_temporaries = sorted(modified_temporaries.values(), - key=lambda t: list(iterkeys(global_trace)).index(t.urepr)) - for t in modified_temporaries: - ast_replace(t.node, to_replace, copy=True) - replaced = [t.urepr for t in to_replace.keys()] - - # Update the temporaries - for t in modified_temporaries: - for r, c in list(iteritems(t.linear_reads_costs)): - if r.urepr in replaced: - t.linear_reads_costs.pop(r) - r_linear_reads_costs = global_trace[r.urepr].linear_reads_costs - for p, p_c in r_linear_reads_costs.items() or [(r, 0)]: - t.linear_reads_costs[p] = c + p_c - - def _transform_temporaries(self, temporaries, decls): - from .rewriter import ExpressionRewriter - - # Never attempt to transform the main expression - temporaries = [t for t in temporaries if t.node not in self.exprs] - - lda = loops_analysis(self.header, key='symbol', value='dim') - - # Expand + Factorize - rewriters = OrderedDict() - for t in temporaries: - expr_info = MetaExpr(self.type, t.main_loop.block, t.main_nest) - ew = ExpressionRewriter(t.node, expr_info, self.header, self.hoisted) - ew.replacediv() - ew.expand(mode='all', lda=lda) - ew.reassociate(lambda i: all(r != t.main_loop.dim for r in lda[i.symbol])) - ew.factorize(mode='adhoc', adhoc={i.urepr: [] for i in t.linear_reads}, lda=lda) - rewriters[t] = ew - - lda = loops_analysis(self.header, value='dim') - - # Code motion - for t, ew in rewriters.items(): - ew.licm(mode='only_outlinear', lda=lda, global_cse=True) - if t.linearity_degree > 1: - ew.licm(mode='only_linear', lda=lda) - - # Keep track of new declarations (recomputation might otherwise be too expensive) - decls.update(OrderedDict([(k, v.decl) for k, v in self.hoisted.items()])) - - def _analyze_expr(self, expr, loop, lda, decls): - reads = Find(Symbol).visit(expr)[Symbol] - reads = [s for s in reads if s.symbol in decls] - syms = [s for s in reads if any(d in loop.dim for d in lda[s])] - - linear_reads_costs = OrderedDict() - - def wrapper(node, found=0): - if isinstance(node, Symbol): - if node in syms: - linear_reads_costs.setdefault(node, 0) - linear_reads_costs[node] += found - return - elif isinstance(node, (EmptyStatement, ArrayInit)): - return - elif isinstance(node, (Prod, Div)): - found += 1 - operands = list(zip(*explore_operator(node)))[0] - for o in operands: - wrapper(o, found) - wrapper(expr) - - return reads, linear_reads_costs - - def _analyze_loop(self, loop, nest, lda, global_trace, decls): - linear_dims = [l.dim for l, _ in nest if l.is_linear] - - trace = OrderedDict() - for node in loop.body: - if not isinstance(node, Writer): - not_ssa = [trace[w] for w in in_written(node, key='urepr') if w in trace] - for t in not_ssa: - t.readby.append(t.symbol) - continue - reads, linear_reads_costs = self._analyze_expr(node.rvalue, loop, lda, decls) - affected = [s for s in reads if any(i in linear_dims for i in lda[s])] - for s in affected: - if s.urepr in global_trace: - temporary = global_trace[s.urepr] - temporary.readby.append(node.lvalue) - temporary = temporary.reconstruct() - temporary.level = -1 - trace[s.urepr] = temporary - else: - temporary = trace.setdefault(s.urepr, Temporary(s, loop, nest)) - temporary.readby.append(node.lvalue) - new_temporary = Temporary(node, loop, nest, linear_reads_costs) - new_temporary.level = max([trace[s.urepr].level for s - in new_temporary.linear_reads] or [-2]) + 1 - trace[node.lvalue.urepr] = new_temporary - - return trace - - def _group_by_level(self, trace): - levels = defaultdict(list) - - for temporary in trace.values(): - levels[temporary.level].append(temporary) - return levels - - def _cost_cse(self, levels, bounds=None): - if bounds is not None: - lb, up = bounds[0], bounds[1] + 1 - levels = {i: levels[i] for i in range(lb, up)} - cost = 0 - for level, temporaries in levels.items(): - cost += sum(t.flops*t.niters('all') for t in temporaries) - return cost - - def _cost_fact(self, trace, levels, lda, bounds): - # Check parameters - assert len(bounds) == 2 and bounds[1] >= bounds[0] - assert bounds[0] in levels.keys() and bounds[1] in levels.keys() - - # Determine current costs of individual loop regions - input_cost = self._cost_cse(levels, (min(levels.keys()), max(levels.keys()))) - uptolevel_cost, post_cse_cost = input_cost, input_cost - level_inloop_cost, total_outloop_cost = 0, 0 - - # We are going to modify a copy of the temporaries dict - new_trace = OrderedDict() - for s, t in trace.items(): - new_trace[s] = t.reconstruct() - - # Cost induced by the untransformed temporaries - pre_cse_cost = self._cost_cse(levels, (min(levels.keys()), bounds[0])) - - best = (bounds[0], bounds[0], uptolevel_cost) - fact_levels = {k: v for k, v in levels.items() if k > bounds[0] and k <= bounds[1]} - for level, temporaries in sorted(fact_levels.items(), key=lambda i_j: i_j[0]): - level_inloop_cost = 0 - for t in temporaries: - # Compute the cost induced by /t/ in the outer loops after fact+licm - t_outloop_cost, linear_reads = 0, [] - for read, cost in t.linear_reads_costs.items(): - traced = new_trace.get(read.urepr) - if traced and traced.level >= bounds[0]: - handle = traced.linear_reads or [read] - if cost: - for i in handle: - # One prod in the closest linear loop - t_outloop_cost += t.niters('out', lda[i]) - # The rest falls outside of the linear loops - t_outloop_cost += (cost - 1)*t.niters('nonlinear') - else: - handle = [read] - linear_reads.extend(handle) - factors = list(itervalues({as_urepr(i): i for i in linear_reads})) - # Take into account the increased number of sums (due to fact) - hoist_region = set.union(*[lda[i] for i in factors]) - niters = t.niters('out', hoist_region) - t_outloop_cost += (len(linear_reads) - len(factors))*niters - total_outloop_cost += t_outloop_cost - - # Compute the cost induced by /t/ in the main loop after fact+licm - # We end up creating n prods and n -1 sums - t_inloop_cost = 2*len(factors) - 1 - level_inloop_cost += t_inloop_cost*t.niters('all') - - # Take into account any hoistable reductions - if t.is_increment: - for i in factors: - handle = [l.dim for l in t.reductions if l.dim not in i.rank] - level_inloop_cost -= t.niters('all') - t.niters('out', handle) - - # Keep the trace up-to-date - linear_reads_costs = {i: 1 for i in factors} - new_trace[t.urepr].linear_reads_costs = linear_reads_costs - - # Some temporaries within levels < /level/ might also appear in - # subsequent loops or levels beyond /level/, so they still contribute - # to the operation count - for t in list(flatten([levels[j] for j in range(level)])): - if any(rb.urepr not in new_trace for rb in t.readby) or \ - any(new_trace[rb.urepr].level > level for rb in t.readby): - # Note: condition 1) is basically saying "if I'm read by - # a temporary that is not in this loop's trace, then I must - # be read in some other loops". - level_inloop_cost += \ - new_trace[t.urepr].flops_projection*t.niters('all') - - post_cse_cost = self._cost_cse(fact_levels, (level + 1, bounds[1])) - - # Compute the total cost - total_inloop_cost = pre_cse_cost + level_inloop_cost + post_cse_cost - uptolevel_cost = total_outloop_cost + total_inloop_cost - - # Update the best alternative - if uptolevel_cost < best[2]: - best = (bounds[0], level, uptolevel_cost) - - log('[CSE]: unpicking between [%d, %d]:' % (bounds[0], level), COST_MODEL) - log(' flops: %d -> %d (hoist=%d, preCSE=%d, fact=%d, postCSE=%d)' % - (input_cost, uptolevel_cost, total_outloop_cost, pre_cse_cost, - level_inloop_cost, post_cse_cost), COST_MODEL) - - return best - - def unpick(self): - # Collect all necessary info - info = visit(self.header, info_items=['decls', 'fors']) - decls, fors = info['decls'], info['fors'] - lda = loops_analysis(self.header, value='dim') - - # Collect all loops to be analyzed - nests = OrderedDict() - for nest in fors: - for loop, parent in nest: - if loop.is_linear: - nests[loop] = nest - - # Analyze loops - global_trace = OrderedDict() - mapper = OrderedDict() - for loop, nest in nests.items(): - trace = self._analyze_loop(loop, nest, lda, global_trace, decls) - if trace: - mapper[loop] = trace - global_trace.update(trace) - - for loop, trace in mapper.items(): - # Compute the best cost alternative - levels = self._group_by_level(trace) - min_level, max_level = min(levels.keys()), max(levels.keys()) - current_cost = self._cost_cse(levels, (min_level, max_level)) - global_best = (min_level, min_level, current_cost) - for i in sorted(levels.keys()): - local_best = self._cost_fact(trace, levels, lda, (i, max_level)) - if local_best[2] < global_best[2]: - global_best = local_best - - log("-- Best: [%d, %d] (cost=%d) --" % global_best, COST_MODEL) - - # Transform the loop - for i in range(global_best[0] + 1, global_best[1] + 1): - ra = reachability_analysis(self.header) - self._push_temporaries(levels[i-1], trace, global_trace, ra, decls) - self._transform_temporaries(levels[i], decls) - - cleanup(self.header) diff --git a/coffee/expander.py b/coffee/expander.py deleted file mode 100644 index 6d2e312a..00000000 --- a/coffee/expander.py +++ /dev/null @@ -1,118 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division - -import itertools - -from .base import * -from .utils import * -from .exceptions import UnexpectedNode - - -class Expander(object): - - """Expand the products in an expression according to a set of rules. For a - comprehensive list of possible rules, refer to the documentation of the - corresponding wrapper function ``expand`` in ``ExpressionRewriter``.""" - - # Constants used by the /expand/ method to charaterize sub-expressions: - GROUP = 0 # Expression /will/ not trigger expansion - EXPAND = 1 # Expression /could/ be expanded - - def __init__(self, stmt): - self.stmt = stmt - - def _build(self, exp, grp, expansions): - """Create a node for the expansion and keep track of it.""" - expansion = Prod(exp, dcopy(grp)) - # Track the new expansion - expansions.append(expansion) - # Untrack any expansions occured in children nodes - if grp in expansions: - expansions.remove(grp) - return expansion - - def _expand(self, node, parent, expansions): - if isinstance(node, Symbol): - return ([node], self.EXPAND) if self.should_expand(node) \ - else ([node], self.GROUP) - - elif isinstance(node, (Div, Ternary, FunCall)): - # Try to expand /within/ the children, but then return saying "I'm not - # expandable any further" - for n in node.children: - self._expand(n, node, expansions) - return ([node], self.GROUP) - - elif isinstance(node, Prod): - l_exps, l_type = self._expand(node.left, node, expansions) - r_exps, r_type = self._expand(node.right, node, expansions) - if l_type == self.GROUP and r_type == self.GROUP: - return ([node], self.GROUP) - # At least one child is expandable (marked as EXPAND), whereas the - # other could either be expandable as well or groupable (marked - # as GROUP): so we can perform the expansion - groupable = l_exps if l_type == self.GROUP else r_exps - expandable = r_exps if l_type == self.GROUP else l_exps - to_replace = OrderedDict() - for exp, grp in itertools.product(expandable, groupable): - expansion = self._build(exp, grp, expansions) - to_replace.setdefault(exp, []).append(expansion) - ast_replace(node, {k: ast_make_expr(Sum, v) for k, v in to_replace.items()}, - copy=False, mode='symbol') - # Update the parent node, since an expression has just been expanded - expanded = node.right if l_type == self.GROUP else node.left - parent.children[parent.children.index(node)] = expanded - return (list(flatten(to_replace.values())) or [expanded], self.EXPAND) - - elif isinstance(node, (Sum, Sub)): - l_exps, l_type = self._expand(node.left, node, expansions) - r_exps, r_type = self._expand(node.right, node, expansions) - if l_type == self.EXPAND and r_type == self.EXPAND and isinstance(node, Sum): - return (l_exps + r_exps, self.EXPAND) - elif l_type == self.EXPAND and r_type == self.EXPAND and isinstance(node, Sub): - return (l_exps + [Neg(r) for r in r_exps], self.EXPAND) - else: - return ([node], self.GROUP) - - else: - raise UnexpectedNode("Expansion: %s" % str(node)) - - def expand(self, should_expand, **kwargs): - expressions = kwargs.get('subexprs', [(self.stmt.rvalue, self.stmt)]) - - self.should_expand = should_expand - - for node, parent in expressions: - self._expand(node, parent, []) diff --git a/coffee/factorizer.py b/coffee/factorizer.py deleted file mode 100644 index 48c2d7a4..00000000 --- a/coffee/factorizer.py +++ /dev/null @@ -1,236 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division - -import operator - -from .base import * -from .utils import * -from functools import reduce - - -class Term(object): - """A Term represents a product between 'operands' and 'factors'. In a - product /a*(b+c)/, /a/ is the 'operand', while /b/ and /c/ are the 'factors'. - The symbol /+/ is the 'op' of the Term. - """ - - def __init__(self, operands, factors=None, op=None): - self.operands = operands - self.factors = factors or [] - self.op = op - - @property - def operands_ast(self): - return ast_make_expr(Prod, self.operands) - - @property - def factors_ast(self): - return ast_make_expr(self.op, self.factors) - - @property - def generate_ast(self): - if len(self.factors) == 0: - return self.operands_ast - elif len(self.operands) == 0: - return self.factors_ast - elif len(self.factors) == 1 and \ - all(isinstance(i, Symbol) and i.symbol == 1.0 for i in self.factors): - return self.operands_ast - else: - return Prod(self.operands_ast, self.factors_ast) - - def add_operands(self, operands): - for o in operands: - if o not in self.operands: - self.operands.append(o) - - def remove_operands(self, operands): - for o in operands: - if o in self.operands: - self.operands.remove(o) - - def add_factors(self, factors): - for f in factors: - if f not in self.factors: - self.factors.append(f) - - def remove_factors(self, factors): - for f in factors: - if f in self.factors: - self.factors.remove(f) - - @staticmethod - def process(symbols, should_factorize, op=None): - operands = [s for s in symbols if should_factorize(s)] - factors = [s for s in symbols if not should_factorize(s)] - return Term(operands, factors, op) - - -class Factorizer(object): - - """Factorize terms in an expression according to a set of rules. For a - comprehensive list of possible rules, refer to the documentation of the - corresponding wrapper function ``factorize`` in ``ExpressionRewriter``.""" - - def __init__(self, stmt): - self.stmt = stmt - - def _simplify_sum(self, terms): - unique_terms = OrderedDict() - for t in terms: - unique_terms.setdefault(str(t.generate_ast), list()).append(t) - - for t_repr, t_list in unique_terms.items(): - occurrences = len(t_list) - unique_terms[t_repr] = t_list[0] - if occurrences > 1: - unique_terms[t_repr].add_factors([Symbol(occurrences)]) - - terms[:] = unique_terms.values() - - def _heuristic_collection(self, terms): - if not self.heuristic or any(t.operands for t in terms): - return - tracker = OrderedDict() - for t in terms: - symbols = [s for s in t.factors if isinstance(s, Symbol)] - for s in symbols: - tracker.setdefault(s.urepr, []).append(t) - reverse_tracker = OrderedDict() - for s, ts in tracker.items(): - reverse_tracker.setdefault(tuple(ts), []).append(s) - # 1) At least one symbol appearing in all terms: use that as operands ... - operands = [(ts, s) for ts, s in reverse_tracker.items() if ts == tuple(terms)] - # 2) ... Or simply pick operands greedily - if not operands: - handled = set() - for ts, s in reverse_tracker.items(): - if len(ts) > 1 and all(t not in handled for t in ts): - operands.append((ts, s)) - handled |= set(ts) - for ts, s in operands: - for t in ts: - new_operands = [i for i in t.factors if - isinstance(i, Symbol) and i.urepr in s] - t.remove_factors(new_operands) - t.add_operands(new_operands) - - def _premultiply_symbols(self, symbols): - floats = [s for s in symbols if isinstance(s.symbol, (int, float))] - if len(floats) > 1: - other_symbols = [s for s in symbols if s not in floats] - prem = reduce(operator.mul, [s.symbol for s in floats], 1.0) - prem = [Symbol(prem)] if prem not in [1, 1.0] else [] - return prem + other_symbols - else: - return symbols - - def _filter(self, factorizable_term): - o = factorizable_term.operands_ast - grp = self.adhoc.get(o.urepr, []) if isinstance(o, Symbol) else [] - if not grp: - return False - for f in factorizable_term.factors: - symbols = Find(Symbol).visit(f)[Symbol] - if any(s.urepr in grp for s in symbols): - return False - return True - - def _factorize(self, node, parent): - if isinstance(node, Symbol): - return Term.process([node], self.should_factorize) - - elif isinstance(node, (FunCall, Div)): - # Try to factorize /within/ the children, but then return saying - # "I'm not factorizable any further" - for n in node.children: - self._factorize(n, node) - return Term([], [node]) - - elif isinstance(node, Prod): - children = explore_operator(node) - symbols = [n for n, _ in children if isinstance(n, Symbol)] - other_nodes = [(n, p) for n, p in children if n not in symbols] - symbols = self._premultiply_symbols(symbols) - factorized = Term.process(symbols, self.should_factorize, Prod) - terms = [self._factorize(n, p) for n, p in other_nodes] - for t in terms: - factorized.add_operands(t.operands) - factorized.add_factors(t.factors) - return factorized - - # The fundamental case is when /node/ is a Sum (or Sub, equivalently). - # Here, we try to factorize the terms composing the operation - elif isinstance(node, (Sum, Sub)): - children = explore_operator(node) - # First try to factorize within /node/'s children - terms = [self._factorize(n, p) for n, p in children] - # Check if it's possible to aggregate operations - # Example: replace (a*b)+(a*b) with 2*(a*b) - self._simplify_sum(terms) - # No global factorization rule is used, so just try to maximize - # factorization within /this/ Sum/Sub - self._heuristic_collection(terms) - # Finally try to factorize some of the operands composing the operation - factorized = OrderedDict() - for t in terms: - operand = [t.operands_ast] if t.operands else [] - factor = [t.factors_ast] if t.factors else [Symbol(1.0)] - factorizable_term = Term(operand, factor, node.__class__) - if self._filter(factorizable_term): - # Skip - factorized[t] = t - else: - # Do factorize - _t = factorized.setdefault(str(t.operands_ast), factorizable_term) - _t.add_factors(factor) - factorized = [t.generate_ast for t in factorized.values()] - factorized = ast_make_expr(Sum, factorized) - parent.children[parent.children.index(node)] = factorized - return Term([], [factorized]) - - else: - return Term([], [node]) - - def factorize(self, should_factorize, **kwargs): - expressions = kwargs.get('subexprs', [(self.stmt.rvalue, self.stmt)]) - adhoc = kwargs.get('adhoc', {}) - - self.should_factorize = should_factorize - self.adhoc = adhoc if any(v for v in adhoc.values()) else {} - self.heuristic = kwargs.get('heuristic', False) - - for node, parent in expressions: - self._factorize(node, parent) diff --git a/coffee/hoister.py b/coffee/hoister.py deleted file mode 100644 index 3c15dfe3..00000000 --- a/coffee/hoister.py +++ /dev/null @@ -1,351 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six.moves import zip - -from .base import * -from .utils import * - - -class Extractor(object): - - EXT = 0 # expression marker: extract - STOP = 1 # expression marker: do not extract - - def __init__(self, stmt, expr_info, should_extract): - self.stmt = stmt - self.expr_info = expr_info - self.should_extract = should_extract - - def _apply_cse(self): - # Find common sub-expressions heuristically looking at binary terminal - # operations (i.e., a terminal has two Symbols as children). This may - # induce more sweeps of extraction to find all common sub-expressions, - # but at least it keeps the algorithm simple and probably more effective - finder = Find(Symbol, with_parent=True) - for dep, subexprs in self.extracted.items(): - cs = OrderedDict() - values = [finder.visit(e)[Symbol] for e in subexprs] - binexprs = list(zip(*flatten(values)))[1] - binexprs = [b for b in binexprs if binexprs.count(b) > 1] - for b in binexprs: - t = cs.setdefault(b.urepr, []) - if b not in t: - t.append(b) - cs = [v for k, v in cs.items() if len(v) > 1] - if cs: - self.extracted[dep] = list(flatten(cs)) - - def _try(self, node, dep): - if isinstance(node, Symbol): - return False - should_extract = self.should_extract(dep) - if should_extract or self._look_ahead: - dep = sorted(dep, key=lambda i: self.expr_info.dims.index(i)) - self.extracted.setdefault(tuple(dep), []).append(node) - return should_extract - - def _visit(self, node): - if isinstance(node, Symbol): - return (self._lda[node], self.EXT) - - elif isinstance(node, (FunCall, Ternary)): - arg_deps = [self._visit(n) for n in node.children] - dep = set(flatten([dep for dep, _ in arg_deps])) - info = self.EXT if all(i == self.EXT for _, i in arg_deps) else self.STOP - return (dep, info) - - else: - retval = [(n,) + self._visit(n) for n in node.children] - dep = set.union(*[d for _, d, _ in retval]) - dep = {d for d in dep if d in self.expr_info.dims} - if self.should_extract(dep) or self._look_ahead: - # Still a chance of finding a bigger expression - return (dep, self.EXT) - else: - for n, n_dep, n_info in retval: - if n_info == self.EXT and not isinstance(n, Symbol): - k = sorted(n_dep, key=lambda i: self.expr_info.dims.index(i)) - self.extracted.setdefault(tuple(k), []).append(n) - return (dep, self.STOP) - - def extract(self, look_ahead, lda, with_cse=False): - """Extract invariant subexpressions from /self.expr/.""" - self._lda = lda - self._look_ahead = look_ahead - self.extracted = OrderedDict() - - self._visit(self.stmt.rvalue) - if with_cse: - self._apply_cse() - - del self._lda - del self._look_ahead - - return self.extracted - - -class Hoister(object): - - # Temporary variables template - _template = "ct%d" - - def __init__(self, stmt, expr_info, header, hoisted): - """Initialize the Hoister.""" - self.stmt = stmt - self.expr_info = expr_info - self.header = header - self.hoisted = hoisted - - def _filter(self, dep, subexprs, make_unique=True, sharing=None): - """Filter hoistable subexpressions.""" - if make_unique: - # Uniquify expressions - subexprs = uniquify(subexprs) - - if sharing: - # Partition expressions such that expressions sharing the same - # set of symbols are in the same partition - if dep == self.expr_info.dims: - return [] - sharing = [str(s) for s in sharing] - partitions = defaultdict(list) - for e in subexprs: - symbols = tuple(set(str(s) for s in Find(Symbol).visit(e)[Symbol] - if str(s) in sharing)) - partitions[symbols].append(e) - for shared, partition in partitions.items(): - if len(partition) > len(shared): - subexprs = [e for e in subexprs if e not in partition] - - return subexprs - - def _is_hoistable(self, subexprs, loop): - """Return True if the sub-expressions provided in ``subexprs`` are - hoistable outside of ``loop``, False otherwise.""" - written = in_written(loop, 'symbol') - reads = Find.default_retval() - for e in subexprs: - Find(Symbol).visit(e, ret=reads) - reads = [s.symbol for s in reads[Symbol]] - return set.isdisjoint(set(reads), set(written)) - - def _locate(self, dep, subexprs, with_promotion=False): - # Start assuming no "real" hoisting can take place - # E.g.: for i {a[i]*(t1 + t2);} --> for i {t3 = t1 + t2; a[i]*t3;} - place, offset = self.expr_info.innermost_loop.block, self.stmt - - if with_promotion: - # Hoist outside a loop even though this doesn't result in any - # operation count reduction - should_jump = lambda l: True - else: - # "Standard" code motion case, i.e. moving /subexprs/ as far as - # possible in the loop nest such that dependencies are honored - should_jump = lambda l: l.dim not in dep - - loops = list(reversed(self.expr_info.loops)) - candidates = [l.block for l in loops[1:]] + [self.header] - - for loop, candidate in zip(loops, candidates): - if not self._is_hoistable(subexprs, loop): - break - if should_jump(loop): - place, offset = candidate, loop - - # Determine how much extra memory and whether clone loops are needed - jumped = loops[:candidates.index(place) + 1] - clone = tuple(l for l in reversed(jumped) if l.dim in dep) - - return place, offset, clone - - def extract(self, should_extract, **kwargs): - """Return a dictionary of hoistable subexpressions.""" - lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - extractor = Extractor(self.stmt, self.expr_info, should_extract) - return extractor.extract(True, lda) - - def licm(self, should_extract, **kwargs): - """Perform generalized loop-invariant code motion.""" - max_sharing = kwargs.get('max_sharing', False) - with_promotion = kwargs.get('with_promotion', False) - iterative = kwargs.get('iterative', True) - lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - global_cse = kwargs.get('global_cse', False) - - extractor = Extractor(self.stmt, self.expr_info, should_extract) - - extracted = True - while extracted: - extracted = extractor.extract(False, lda, global_cse) - for dep, subexprs in extracted.items(): - # 1) Filter subexpressions that will be hoisted - sharing = [] - if max_sharing: - sharing = uniquify([s for s, d in lda.items() if d == dep]) - subexprs = self._filter(dep, subexprs, sharing=sharing) - if not subexprs: - continue - - # 2) Determine the outermost loop where invariant expressions - # can be hoisted without breaking data dependencies. - place, offset, clone = self._locate(dep, subexprs, with_promotion) - - loop_size = tuple(l.size for l in clone) - loop_dim = tuple(l.dim for l in clone) - - # 3) Create the required new AST nodes - symbols, decls, stmts = [], [], [] - for e in subexprs: - already_hoisted = False - if global_cse and self.hoisted.get_symbol(e): - name = self.hoisted.get_symbol(e) - decl = self.hoisted[name].decl - if decl in place.children and \ - place.children.index(decl) < place.children.index(offset): - already_hoisted = True - if not already_hoisted: - name = self._template % (len(self.hoisted) + len(stmts)) - stmts.append(Assign(Symbol(name, loop_dim), dcopy(e))) - decls.append(Decl(self.expr_info.type, - Symbol(name, loop_size), - scope=LOCAL)) - symbols.append(Symbol(name, loop_dim)) - - # 4) Replace invariant sub-expressions with temporaries - replacements = ast_replace(self.stmt, dict(zip(subexprs, symbols))) - - # 5) Modify the AST adding the hoisted expressions - if clone: - outer_clone = ast_make_for(stmts, clone[-1]) - for l in reversed(clone[:-1]): - outer_clone = ast_make_for([outer_clone], l) - code = decls + [outer_clone] - clone = outer_clone - else: - code = decls + stmts - clone = None - offset = place.children.index(offset) - place.children[offset:offset] = code - - # 6) Track hoisted symbols and data dependencies - for i, j in zip(stmts, decls): - name = j.lvalue.symbol - self.hoisted[name] = (i, j, clone, place) - lda.update({s: set(dep) for s in replacements}) - - if not iterative: - break - - def trim(self, candidate, **kwargs): - """ - Remove unnecessary reduction loops from the expression loop nest. - Sometimes, reduction loops can be factored out in outer loops, thus - reducing the operation count, without breaking data dependencies. - """ - # Rule out unsafe cases - if not is_perfect_loop(self.expr_info.innermost_loop): - return - - # Find out all reducible symbols - lda = kwargs.get('lda') or loops_analysis(self.header) - reducible, other = [], [] - for i in summands(self.stmt.rvalue): - symbols = Find(Symbol).visit(i)[Symbol] - unavoidable = set.intersection(*[set(lda[s]) for s in symbols]) - if candidate in unavoidable: - return - reducible.extend([s.symbol for s in symbols if candidate in lda[s]]) - other.extend([s.symbol for s in symbols if candidate not in lda[s]]) - - # Make sure we do not break data dependencies - make_reduce = [] - writes = Find(Writer).visit(candidate) - for w in flatten(writes.values()): - if isinstance(w.rvalue, EmptyStatement): - continue - if any(s == w.lvalue.symbol for s in other): - return - if any(s == w.lvalue.symbol for s in reducible): - loop = lda[w.lvalue][-1] - make_reduce.append((w, loop)) - - assignments = [(w, p) for w, p in make_reduce if isinstance(w, Assign)] - loops, parents = zip(*self.expr_info.loops_info) - index = loops.index(candidate) - - # Perform a number of checks to ensure lifting reductions is safe - if not all(s in [w.lvalue.symbol for w, _ in make_reduce] for s in reducible): - return - if any(p != candidate and not is_perfect_loop(p) for w, p in make_reduce): - return - if any(candidate.dim in w.lvalue.rank for w, _ in assignments): - return - if any(set(loops[index + 1:]) & set(lda[w.lvalue]) for w, _ in make_reduce): - return - - # Inject the reductions into the AST - decls = visit(self.header, info_items=['decls'])['decls'] - for w, p in make_reduce: - name = self._template % len(self.hoisted) - reduction = Incr(Symbol(name, w.lvalue.rank, w.lvalue.offset), - ast_reconstruct(w.rvalue)) - insert_at_elem(p.body, w, reduction) - handle = decls[w.lvalue.symbol] - declaration = Decl(handle.typ, Symbol(name, handle.lvalue.rank), - ArrayInit(np.array([0.0])), handle.qual, handle.attr) - insert_at_elem(parents[index].children, candidate, declaration) - ast_replace(self.stmt, {w.lvalue: reduction.lvalue}, copy=True) - self.hoisted[name] = (reduction, declaration, p, p.body) - - # Pull out the candidate reduction loop - pulling = loops[index + 1:] - pulling = list(zip(*[((l.start, l.end), l.dim) for l in pulling])) - pulling = ItSpace().to_for(*pulling, stmts=[self.stmt]) - insert_at_elem(parents[index].children, candidate, pulling[0], ofs=1) - if len(self.expr_info.parent.children) == 1: - loops[index].body.remove(loops[index + 1]) - else: - self.expr_info.parent.children.remove(self.stmt) - - # Clean up removing any now unnecessary symbols - reads = in_read(candidate, key='symbol') - declarations = Find(Decl, with_parent=True).visit(self.header)[Decl] - declarations = dict(declarations) - for w, p in make_reduce: - if w.lvalue.symbol not in reads: - p.body.remove(w) - if not isinstance(w, Decl): - key = decls[w.lvalue.symbol] - declarations[key].children.remove(key) diff --git a/coffee/optimizer.py b/coffee/optimizer.py index fb4cd231..1a437e00 100644 --- a/coffee/optimizer.py +++ b/coffee/optimizer.py @@ -31,24 +31,7 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import absolute_import, print_function, division -from six.moves import zip - -import operator -import resource -from collections import OrderedDict -from itertools import combinations -from math import factorial as fact - -from . import system -from .base import * -from .utils import * -from .scheduler import ExpressionFissioner, ZeroRemover, SSALoopMerger -from .rewriter import ExpressionRewriter -from .cse import CSEUnpicker -from .logger import warn -from coffee.visitors import Find, ProjectExpansion -from functools import reduce +from .utils import StmtTracker class LoopOptimizer(object): @@ -69,436 +52,6 @@ def __init__(self, loop, header, exprs): # Track hoisted expressions self.hoisted = StmtTracker() - def rewrite(self, mode): - """Rewrite all compute-intensive expressions detected in the loop nest to - minimize the number of floating point operations performed. - - :param mode: Any value in (0, 1, 2, 3, 4). Each ``mode`` corresponds to a - different expression rewriting strategy. - - * mode == 0: no rewriting is performed. - * mode == 1: generalized loop-invariant code motion. - * mode == 2: apply four passes: generalized loop-invariant code motion; - expansion of inner-loop dependent expressions; factorization of - inner-loop dependent terms; generalized loop-invariant code motion. - * mode == 3: apply multiple passes; aims at pre-evaluating sub-expressions - that fully depend on reduction loops. - * mode == 4: rewrite an expression based on its sharing graph - """ - # Set a rewrite mode for each expression - for stmt, expr_info in self.exprs.items(): - expr_info.mode = mode - - # Analyze the individual expressions and try to select an optimal rewrite - # mode for each of them. A preliminary transformation of the loop nest may - # take place in this pass (e.g., injection) - if mode == 'auto': - self._dissect('greedy') - elif mode == 'auto-aggressive': - self._dissect('aggressive') - - # Search for factorization opportunities across temporaries in the kernel - if mode > 1 and self.exprs: - self._unpick_cse() - - # Expression rewriting, expressed as a sequence of AST transformation passes - for stmt, expr_info in self.exprs.items(): - ew = ExpressionRewriter(stmt, expr_info, self.header, self.hoisted) - - if expr_info.mode == 1: - if expr_info.dimension in [0, 1]: - ew.licm(mode='only_outlinear') - else: - ew.licm() - - elif expr_info.mode == 2: - if expr_info.dimension > 0: - ew.replacediv() - ew.sharing_graph_rewrite() - ew.licm(mode='reductions') - - elif expr_info.mode == 3: - ew.expand(mode='all') - ew.factorize(mode='all') - ew.licm(mode='only_const') - ew.factorize(mode='constants') - ew.licm(mode='aggressive') - ew.preevaluate() - ew.factorize(mode='linear') - ew.licm(mode='only_const') - - elif expr_info.mode == 4: - ew.replacediv() - ew.factorize() - ew.licm(mode='only_outlinear') - if expr_info.dimension > 0: - ew.licm(mode='only_linear', iterative=False, max_sharing=True) - ew.sharing_graph_rewrite() - ew.expand() - - # Try merging the loops created by expression rewriting - merged_loops = SSALoopMerger().merge(self.header) - # Update the trackers - for merged, merged_in in merged_loops: - for l in merged: - self.hoisted.update_loop(l, merged_in) - # Was /merged/ an expression loops? If so, need to update the - # corresponding MetaExpr - for stmt, expr_info in self.exprs.items(): - if expr_info.loops[-1] == l: - expr_info._loops_info[-1] = (merged_in, expr_info.loops_parents[-1]) - expr_info._parent = merged_in.children[0] - - # Reduce memory pressure by avoiding useless temporaries - self._min_temporaries() - - # Handle the effects, at the C-level, of the AST transformation - self._recoil() - - def eliminate_zeros(self): - """Restructure the iteration spaces nested in this LoopOptimizer to - avoid evaluation of arithmetic operations involving zero-valued blocks - in statically initialized arrays.""" - - zls = ZeroRemover(self.exprs, self.hoisted) - self.nz_syms = zls.reschedule(self.header) - - def _unpick_cse(self): - """Search for factorization opportunities across temporaries created by - common sub-expression elimination. If a gain in operation count is detected, - unpick CSE and apply factorization + code motion.""" - cse_unpicker = CSEUnpicker(self.exprs, self.header, self.hoisted) - cse_unpicker.unpick() - - def _min_temporaries(self): - """Remove unnecessary temporaries, thus relieving memory pressure. - A temporary is removed iff: - - * it is written once, AND - * it is read once OR it is read n times, but it hosts only a Symbol - """ - - occurrences = count(self.header, mode='symbol_id', read_only=True) - - for l in self.hoisted.all_loops: - info = visit(l, info_items=['symbol_refs', 'symbols_mode']) - to_replace, to_remove = {}, [] - for (temporary, _, _), c in count(l, read_only=True).items(): - if temporary not in self.hoisted: - continue - if self.hoisted[temporary].loop is not l: - continue - if occurrences.get(temporary) != c: - continue - decl = self.hoisted[temporary].decl - place = self.hoisted[temporary].place - expr = self.hoisted[temporary].stmt.rvalue - if c > 1 and explore_operator(expr): - continue - references = info['symbol_refs'][temporary] - syms_mode = info['symbols_mode'] - # Note: only one write is possible at this point - write = [(s, p) for s, p in references if syms_mode[s][0] == WRITE][0] - to_replace[write[0]] = expr - to_remove.append(write[1]) - place.children.remove(decl) - # Update trackers - self.hoisted.pop(temporary) - - # Replace temporary symbols and clean up - l_innermost_body = inner_loops(l)[-1].body - for stmt in l_innermost_body: - if stmt.lvalue in to_replace: - continue - while ast_replace(stmt, to_replace, copy=True): - pass - for stmt in to_remove: - l_innermost_body.remove(stmt) - - def _dissect(self, heuristics): - """Analyze the set of expressions in the LoopOptimizer and infer an - optimal rewrite mode for each of them. - - If an expression is embedded in a non-perfect loop nest, then injection - may be performed. Injection consists of unrolling any loops outside of - the expression iteration space into the expression itself. - For example: :: - - for i - for r - a += B[r]*C[i][r] - for j - for k - A[j][k] += ...f(a)... // the expression at hand - - gets transformed into: - - for i - for j - for k - A[j][k] += ...f(B[0]*C[i][0] + B[1]*C[i][1] + ...)... - - Injection could be necessary to maximize the impact of rewrite mode=3, - which tries to pre-evaluate subexpressions whose values are known at - code generation time. Injection is essential to factorize such subexprs. - - :arg heuristic: any value in ['greedy', 'aggressive']. With 'greedy', a greedy - approach is used to decide which of the expressions for which injection - looks beneficial should be dissected (e.g., injection increases the memory - footprint, and some memory constraints must always be preserved). - With 'aggressive', the whole space of possibilities is analyzed. - """ - # The memory threshold. The total size of temporaries will not have to - # be greated than this value. If we predict that injection will lead - # to too much temporary space, we have to partially drop it - threshold = system.architecture['cache_size'] * 1.2 - - expr_graph = ExpressionGraph(header) - - # 1) Find out and unroll injectable loops. For unrolling we create new - # expressions; that is, for now, we do not modify the AST in place. - analyzed, injectable = [], {} - for stmt, expr_info in self.exprs.items(): - # Get all loop nests, then discard the one enclosing the expression - nests = [n for n in visit(expr_info.loops_parents[0])['fors']] - injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops] - - for nest in injectable_nests: - to_unroll = [(l, p) for l, p in nest if l not in expr_info.loops] - unroll_cost = reduce(operator.mul, (l.size for l, p in to_unroll)) - - nest_writers = Find(Writer).visit(to_unroll[0][0]) - for op, i_stmts in nest_writers.items(): - # Check safety of unrolling - if op in [Assign, IMul, IDiv]: - continue - assert op in [Incr, Decr] - - for i_stmt in i_stmts: - i_sym, i_expr = i_stmt.children - - # Avoid injecting twice the same loop - if i_stmt in analyzed + [l.incr for l, p in to_unroll]: - continue - analyzed.append(i_stmt) - - # Create unrolled, injectable expressions - for l, p in reversed(to_unroll): - i_expr = [dcopy(i_expr) for i in range(l.size)] - for i, e in enumerate(i_expr): - e_syms = Find(Symbol).visit(e)[Symbol] - for s in e_syms: - s.rank = tuple([r if r != l.dim else i for r in s.rank]) - i_expr = ast_make_expr(Sum, i_expr) - - # Track the unrolled, injectable expressions and their cost - if i_sym.symbol in injectable: - old_i_expr, old_cost = injectable[i_sym.symbol] - new_i_expr = ast_make_expr(Sum, [i_expr, old_i_expr]) - new_cost = unroll_cost + old_cost - injectable[i_sym.symbol] = (new_i_expr, new_cost) - else: - injectable[i_sym.symbol] = (i_expr, unroll_cost) - - # 2) Will rewrite mode=3 be cheaper than rewrite mode=2? - def find_save(target_expr, expr_info): - save_factor = [l.size for l in expr_info.out_linear_loops] or [1] - save_factor = reduce(operator.mul, save_factor) - # The save factor should be multiplied by the number of terms - # that will /not/ be pre-evaluated. To obtain this number, we - # can exploit the linearity of the expression in the terms - # depending on the linear loops. - syms = Find(Symbol).visit(target_expr)[Symbol] - inner = lambda s: any(r == expr_info.linear_dims[-1] for r in s.rank) - nterms = len(set(s.symbol for s in syms if inner(s))) - save = nterms * save_factor - return save_factor, save - - should_unroll = True - storage = 0 - i_syms, injected = injectable.keys(), defaultdict(list) - for stmt, expr_info in self.exprs.items(): - sym, expr = stmt.children - - # Divide /expr/ into subexpressions, each subexpression affected - # differently by injection - if i_syms: - dissected = find_expression(expr, Prod, expr_info.linear_dims, i_syms) - leftover = find_expression(expr, dims=expr_info.linear_dims, out_syms=i_syms) - leftover = {(): list(flatten(leftover.values()))} - dissected = dict(dissected.items() + leftover.items()) - else: - dissected = {(): [expr]} - if any(i not in flatten(dissected.keys()) for i in i_syms): - should_unroll = False - continue - - # Apply the profitability model - analysis = OrderedDict() - for i_syms, target_exprs in dissected.items(): - for target_expr in target_exprs: - - # *** Save *** - save_factor, save = find_save(target_expr, expr_info) - - # *** Cost *** - # The number of operations increases by a factor which - # corresponds to the number of possible /combinations with - # repetitions/ in the injected-values set. We consider - # combinations and not dispositions to take into account the - # (future) effect of factorization. - retval = ProjectExpansion.default_retval() - projection = ProjectExpansion(i_syms).visit(target_expr, ret=retval) - projection = [i for i in projection if i] - increase_factor = 0 - for i in projection: - partial = 1 - for j in expr_graph.shares(i): - # _n=number of unique elements, _k=group size - _n = injectable[j[0]][1] - _k = len(j) - partial *= fact(_n + _k - 1)//(fact(_k)*fact(_n - 1)) - increase_factor += partial - increase_factor = increase_factor or 1 - if increase_factor > save_factor: - # We immediately give up if this holds since it ensures - # that /cost > save/ (but not that cost <= save) - should_unroll = False - continue - # The increase factor should be multiplied by the number of - # terms that will be pre-evaluated. To obtain this number, - # we need to project the output of factorization. - fake_stmt = stmt.__class__(stmt.children[0], dcopy(target_expr)) - fake_parent = expr_info.parent.children - fake_parent[fake_parent.index(stmt)] = fake_stmt - ew = ExpressionRewriter(fake_stmt, expr_info) - ew.expand(mode='all').factorize(mode='all').factorize(mode='linear') - nterms = ew.licm(mode='aggressive', look_ahead=True) - nterms = len(uniquify(nterms[expr_info.dims])) or 1 - fake_parent[fake_parent.index(fake_stmt)] = stmt - cost = nterms * increase_factor - - # Pre-evaluation will also increase the working set size by - # /cost/ * /sizeof(term)/. - size = [l.size for l in expr_info.linear_loops] - size = reduce(operator.mul, size, 1) - storage_increase = cost * size * system.architecture[expr_info.type] - - # Track the injectable sub-expression and its cost/save. The - # final decision of whether to actually perform injection or not - # is postponed until all dissected expressions have been analyzed - analysis[target_expr] = (cost, save, storage_increase) - - # So what should we inject afterall ? Time to *use* the cost model - if heuristics == 'greedy': - for target_expr, (cost, save, storage_increase) in analysis.items(): - if cost > save or storage_increase + storage > threshold: - should_unroll = False - else: - # Update the available storage - storage += storage_increase - # At this point, we can happily inject - to_replace = {k: v[0] for k, v in injectable.items()} - ast_replace(target_expr, to_replace, copy=True) - injected[stmt].append(target_expr) - elif heuristics == 'aggressive': - # A) Remove expression that we already know should never be injected - not_injected = [] - for target_expr, (cost, save, storage_increase) in analysis.items(): - if cost > save: - should_unroll = False - analysis.pop(target_expr) - not_injected.append(target_expr) - # B) Find all possible bipartitions: each bipartition represents - # the set of expressions that will be pre-evaluated and the set - # of expressions that could also be pre-evaluated, but might not - # (e.g. because of memory constraints) - target_exprs = analysis.keys() - bipartitions = [] - for i in range(len(target_exprs)+1): - for e1 in combinations(target_exprs, i): - bipartitions.append((e1, tuple(e2 for e2 in target_exprs - if e2 not in e1))) - # C) Eliminate those bipartitions that would lead to exceeding - # the memory threshold - bipartitions = [(e1, e2) for e1, e2 in bipartitions if - sum(analysis[i][2] for i in e1) <= threshold] - # D) Find out what is best to pre-evaluate (and therefore - # what should be injected) - totals = OrderedDict() - for e1, e2 in bipartitions: - # Is there any value in actually not pre-evaluating the - # expressions in /e2/ ? - fake_expr = ast_make_expr(Sum, list(e2) + not_injected) - _, save = find_save(fake_expr, expr_info) if fake_expr else (0, 0) - cost = sum(analysis[i][0] for i in e1) - totals[(e1, e2)] = save + cost - best = min(totals, key=totals.get) - # At this point, we can happily inject - to_replace = {k: v[0] for k, v in injectable.items()} - for target_expr in best[0]: - ast_replace(target_expr, to_replace, copy=True) - injected[stmt].append(target_expr) - if best[1]: - # At least one non-injected expressions, let's be sure we - # don't unroll everything - should_unroll = False - - # 3) Purge the AST from now useless symbols/expressions - if should_unroll: - decls = visit(self.header, info_items=['decls'])['decls'] - for stmt, expr_info in self.exprs.items(): - nests = [n for n in visit(expr_info.loops_parents[0])['fors']] - injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops] - for nest in injectable_nests: - unrolled = [(l, p) for l, p in nest if l not in expr_info.loops] - for l, p in unrolled: - p.children.remove(l) - for i_sym in injectable.keys(): - decl = decls.get(i_sym) - if decl and decl in p.children: - p.children.remove(decl) - - # 4) Split the expressions if injection has been performed - for stmt, expr_info in self.exprs.items(): - expr_info.mode = 4 - inj_exprs = injected.get(stmt) - if not inj_exprs: - continue - fissioner = ExpressionFissioner(match=inj_exprs, loops='all', perfect=True) - new_exprs = fissioner.fission(stmt, self.exprs.pop(stmt)) - self.exprs.update(new_exprs) - for stmt, expr_info in new_exprs.items(): - expr_info.mode = 3 if stmt in fissioner.matched else 4 - - def _recoil(self): - """Increase the stack size if the kernel arrays exceed the stack limit - threshold (at the C level).""" - decls = visit(self.header, info_items=['decls'])['decls'] - - # Assume the size of a C type double is 8 bytes - c_double_size = 8 - # Assume the stack size is 1.7 MB (2 MB is usually the limit) - stack_size = 1.7*1024*1024 - - decls = [d for d in decls.values() if d.size] - size = sum([reduce(operator.mul, d.sym.rank) for d in decls]) - - if size * c_double_size > stack_size: - # Increase the stack size if the kernel's stack size seems to outreach - # the space available - try: - resource.setrlimit(resource.RLIMIT_STACK, (resource.RLIM_INFINITY, - resource.RLIM_INFINITY)) - except resource.error: - warn("Stack may blow up, could not increase its size.") - - @property - def expr_loops(self): - """Return ``[(loop1, loop2, ...), ...]``, where each tuple contains all - loops enclosing expressions.""" - return [expr_info.loops for expr_info in self.exprs.values()] - @property def expr_linear_loops(self): """Return ``[(loop1, loop2, ...), ...]``, where each tuple contains all @@ -510,70 +63,7 @@ class CPULoopOptimizer(LoopOptimizer): """Loop optimizer for CPU architectures.""" - def split(self, cut=1): - """Split expressions into multiple chunks exploiting sum's associativity. - Each chunk will have ``cut`` summands. - - For example, consider the following piece of code: :: - - for i - for j - A[i][j] += X[i]*Y[j] + Z[i]*K[j] + B[i]*X[j] - - If ``cut=1`` the expression is cut into chunks of length 1: :: - - for i - for j - A[i][j] += X[i]*Y[j] - for i - for j - A[i][j] += Z[i]*K[j] - for i - for j - A[i][j] += B[i]*X[j] - - If ``cut=2`` the expression is cut into chunks of length 2, plus a - remainder chunk of size 1: :: - - for i - for j - A[i][j] += X[i]*Y[j] + Z[i]*K[j] - # Remainder: - for i - for j - A[i][j] += B[i]*X[j] - """ - - new_exprs = OrderedDict() - elf = ExpressionFissioner(cut=cut, loops='expr') - for stmt, expr_info in self.exprs.items(): - new_exprs.update(elf.fission(stmt, expr_info)) - self.exprs = new_exprs - class GPULoopOptimizer(LoopOptimizer): """Loop optimizer for GPU architectures.""" - - def extract(self): - """Remove the fully-parallel loops of the loop nest. No data dependency - analysis is performed; rather, these are the loops that are marked with - ``pragma coffee itspace``.""" - - info = visit(self.loop, self.header, info_items=['symbols_dep', 'fors']) - symbols = info['symbols_dep'] - - itspace_vrs = set() - for nest in info['fors']: - for loop, parent in reversed(nest): - if '#pragma coffee itspace' not in loop.pragma: - continue - parent = parent.children - for n in loop.body: - parent.insert(parent.index(loop), n) - parent.remove(loop) - itspace_vrs.add(loop.dim) - - accessed_vrs = [s for s in symbols if any_in(s.rank, itspace_vrs)] - - return (itspace_vrs, accessed_vrs) diff --git a/coffee/rewriter.py b/coffee/rewriter.py deleted file mode 100644 index 5a2621f2..00000000 --- a/coffee/rewriter.py +++ /dev/null @@ -1,639 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six.moves import zip - -from collections import Counter -from itertools import combinations -from operator import itemgetter - -from .base import * -from .utils import * -from coffee.visitors import * -from .hoister import Hoister -from .expander import Expander -from .factorizer import Factorizer -from .logger import warn - - -class ExpressionRewriter(object): - """Provide operations to re-write an expression: - - * Loop-invariant code motion: find and hoist sub-expressions which are - invariant with respect to a loop - * Expansion: transform an expression ``(a + b)*c`` into ``(a*c + b*c)`` - * Factorization: transform an expression ``a*b + a*c`` into ``a*(b+c)``""" - - def __init__(self, stmt, expr_info, header=None, hoisted=None): - """Initialize the ExpressionRewriter. - - :param stmt: the node whose rvalue is the expression for rewriting - :param expr_info: ``MetaExpr`` object describing the expression - :param header: the kernel's top node - :param hoisted: dictionary that tracks all hoisted expressions - """ - self.stmt = stmt - self.expr_info = expr_info - self.header = header or Root() - self.hoisted = hoisted if hoisted is not None else StmtTracker() - - self.codemotion = Hoister(self.stmt, self.expr_info, self.header, self.hoisted) - self.expander = Expander(self.stmt) - self.factorizer = Factorizer(self.stmt) - - def licm(self, mode='normal', **kwargs): - """Perform generalized loop-invariant code motion, a transformation - detailed in a paper available at: - - http://dl.acm.org/citation.cfm?id=2687415 - - :param mode: drive code motion by specifying what subexpressions should - be hoisted and where. - * normal: (default) all subexpressions that depend on one loop at most - * aggressive: all subexpressions, depending on any number of loops. - This may require introducing N-dimensional temporaries. - * incremental: apply, in sequence, only_const, only_outlinear, and - one sweep for each linear dimension - * only_const: only all constant subexpressions - * only_linear: only all subexpressions depending on linear loops - * only_outlinear: only all subexpressions independent of linear loops - * reductions: all sub-expressions that are redundantly computed within - a reduction loop; if possible, pull the reduction loop out of - the nest. - :param kwargs: - * look_ahead: (default: False) should be set to True if only a projection - of the hoistable subexpressions is needed (i.e., hoisting not performed) - * max_sharing: (default: False) should be set to True if hoisting should be - avoided in case the same set of symbols appears in different hoistable - sub-expressions. By not hoisting, factorization opportunities are preserved - * iterative: (default: True) should be set to False if interested in - hoisting only the smallest subexpressions matching /mode/ - * lda: an up-to-date loop dependence analysis, as returned by a call - to ``loops_analysis(node, 'dim'). By providing this information, loop - dependence analysis can be avoided, thus speeding up the transformation. - * global_cse: (default: False) search for common sub-expressions across - all previously hoisted terms. Note that no data dependency analysis is - performed, so this is at caller's risk. - * with_promotion: compute hoistable subexpressions within clone loops - even though this doesn't necessarily result in fewer operations. - - Examples - ======== - - 1) With mode='normal': :: - - for i - for j - for k - a[j][k] += (b[i][j] + c[i][j])*(d[i][k] + e[i][k]) - - Redundancies are spotted along both the i and j dimensions, resulting in: :: - - for i - for k - ct1[k] = d[i][k] + e[i][k] - for j - ct2 = b[i][j] + c[i][j] - for k - a[j][k] += ct2*ct1[k] - - 2) With mode='reductions'. - Consider the following loop nest: :: - - for i - for j - a[j] += b[j]*c[i] - - By unrolling the loops, one clearly sees that: :: - - a[0] += b[0]*c[0] + b[0]*c[1] + b[0]*c[2] + ... - a[1] += b[1]*c[0] + b[1]*c[1] + b[1]*c[2] + ... - - Which is identical to: :: - - ct = c[0] + c[1] + c[2] + ... - a[0] += b[0]*ct - a[1] += b[1]*ct - - Thus, the original loop nest is simplified as: :: - - for i - ct += c[i] - for j - a[j] += b[j]*ct - """ - - dimension = self.expr_info.dimension - dims = set(self.expr_info.dims) - linear_dims = set(self.expr_info.linear_dims) - out_linear_dims = set(self.expr_info.out_linear_dims) - - if kwargs.get('look_ahead'): - hoist = self.codemotion.extract - else: - hoist = self.codemotion.licm - - if mode == 'normal': - should_extract = lambda d: d != dims - hoist(should_extract, **kwargs) - elif mode == 'reductions': - should_extract = lambda d: d != dims - # Expansion and reassociation may create hoistable reduction loops - candidates = self.expr_info.reduction_loops - if not candidates: - return self - candidate = candidates[-1] - if candidate.size == 1: - # Otherwise the operation count will just end up increasing - return - self.expand(mode='all') - lda = loops_analysis(self.header, value='dim') - non_candidates = {l.dim for l in candidates[:-1]} - self.reassociate(lambda i: not lda[i].intersection(non_candidates)) - hoist(should_extract, with_promotion=True, lda=lda) - self.codemotion.trim(candidate) - elif mode == 'incremental': - lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') - should_extract = lambda d: not (d and d.issubset(dims)) - hoist(should_extract, lda=lda) - should_extract = lambda d: d.issubset(out_linear_dims) - hoist(should_extract, lda=lda) - for i in range(1, dimension): - should_extract = lambda d: len(d.intersection(linear_dims)) <= i - hoist(should_extract, lda=lda, **kwargs) - elif mode == 'only_const': - should_extract = lambda d: not (d and d.issubset(dims)) - hoist(should_extract, **kwargs) - elif mode == 'only_outlinear': - should_extract = lambda d: d.issubset(out_linear_dims) - hoist(should_extract, **kwargs) - elif mode == 'only_linear': - should_extract = lambda d: not d.issubset(out_linear_dims) and d != linear_dims - hoist(should_extract, **kwargs) - elif mode == 'aggressive': - should_extract = lambda d: True - self.reassociate() - hoist(should_extract, with_promotion=True, **kwargs) - else: - warn('Skipping unknown licm strategy.') - return self - - return self - - def expand(self, mode='standard', **kwargs): - """Expand expressions based on different rules. For example: :: - - (X[i] + Y[j])*F + ... - - can be expanded into: :: - - (X[i]*F + Y[j]*F) + ... - - The expanded term could also be lifted. For example, if we have: :: - - Y[j] = f(...) - (X[i]*Y[j])*F + ... - - where ``Y`` was produced by code motion, expansion results in: :: - - Y[j] = f(...)*F - (X[i]*Y[j]) + ... - - Reasons for expanding expressions include: - - * Exposing factorization opportunities - * Exposing higher level operations (e.g., matrix multiplies) - * Relieving register pressure - - :param mode: multiple expansion strategies are possible - * mode == 'standard': expand along the loop dimension appearing most - often in different symbols - * mode == 'dimensions': expand along the loop dimensions provided in - /kwargs['dimensions']/ - * mode == 'all': expand when symbols depend on at least one of the - expression's dimensions - * mode == 'linear': expand when symbols depend on the expressions's - linear loops. - * mode == 'outlinear': expand when symbols are independent of the - expression's linear loops. - :param kwargs: - * subexprs: an iterator of subexpressions rooted in /self.stmt/. If - provided, expansion will be performed only within these trees, - rather than within the whole expression. - * lda: an up-to-date loop dependence analysis, as returned by a call - to ``loops_analysis(node, 'symbol', 'dim'). By providing this - information, loop dependence analysis can be avoided, thus - speeding up the transformation. - """ - - if mode == 'standard': - symbols = Find(Symbol).visit(self.stmt.rvalue)[Symbol] - # The heuristics privileges linear dimensions - dims = self.expr_info.out_linear_dims - if not dims or self.expr_info.dimension >= 2: - dims = self.expr_info.linear_dims - # Get the dimension occurring most often - occurrences = [tuple(r for r in s.rank if r in dims) for s in symbols] - occurrences = [i for i in occurrences if i] - if not occurrences: - return self - # Finally, establish the expansion dimension - dimension = Counter(occurrences).most_common(1)[0][0] - should_expand = lambda n: set(dimension).issubset(set(n.rank)) - elif mode == 'dimensions': - dimensions = kwargs.get('dimensions', ()) - should_expand = lambda n: set(dimensions).issubset(set(n.rank)) - elif mode in ['all', 'linear', 'outlinear']: - lda = kwargs.get('lda') or loops_analysis(self.expr_info.outermost_loop, - key='symbol', value='dim') - if mode == 'all': - should_expand = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.dims for r in lda[n.symbol]) - elif mode == 'linear': - should_expand = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.linear_dims for r in lda[n.symbol]) - elif mode == 'outlinear': - should_expand = lambda n: lda.get(n.symbol) and \ - not lda[n.symbol].issubset(set(self.expr_info.linear_dims)) - else: - warn('Skipping unknown expansion strategy.') - return - - self.expander.expand(should_expand, **kwargs) - return self - - def factorize(self, mode='standard', **kwargs): - """Factorize terms in the expression. For example: :: - - A[i]*B[j] + A[i]*C[j] - - becomes :: - - A[i]*(B[j] + C[j]). - - :param mode: multiple factorization strategies are possible. Note that - different strategies may expose different code motion opportunities - - * mode == 'standard': factorize symbols along the dimension that appears - most often in the expression. - * mode == 'dimensions': factorize symbols along the loop dimensions provided - in /kwargs['dimensions']/ - * mode == 'all': factorize symbols depending on at least one of the - expression's dimensions. - * mode == 'linear': factorize symbols depending on the expression's - linear loops. - * mode == 'outlinear': factorize symbols independent of the expression's - linear loops. - * mode == 'constants': factorize symbols independent of any loops enclosing - the expression. - * mode == 'adhoc': factorize only symbols in /kwargs['adhoc']/ (details below) - * mode == 'heuristic': no global factorization rule is used; rather, within - each Sum tree, factorize the symbols appearing most often in that tree - :param kwargs: - * subexprs: an iterator of subexpressions rooted in /self.stmt/. If - provided, factorization will be performed only within these trees, - rather than within the whole expression - * adhoc: a list of symbols that can be factorized and, for each symbol, - a list of symbols that can be grouped. For example, if we have - ``kwargs['adhoc'] = [(A, [B, C]), (D, [E, F, G])]``, and the - expression is ``A*B + D*E + A*C + A*F``, the result will be - ``A*(B+C) + A*F + D*E``. If the A's list were empty, all of the - three symbols B, C, and F would be factorized. Recall that this - option is ignored unless ``mode == 'adhoc'``. - * lda: an up-to-date loop dependence analysis, as returned by a call - to ``loops_analysis(node, 'symbol', 'dim'). By providing this - information, loop dependence analysis can be avoided, thus - speeding up the transformation. - """ - - if mode == 'standard': - symbols = Find(Symbol).visit(self.stmt.rvalue)[Symbol] - # The heuristics privileges linear dimensions - dims = self.expr_info.out_linear_dims - if not dims or self.expr_info.dimension >= 2: - dims = self.expr_info.linear_dims - # Get the dimension occurring most often - occurrences = [tuple(r for r in s.rank if r in dims) for s in symbols] - occurrences = [i for i in occurrences if i] - if not occurrences: - return self - # Finally, establish the factorization dimension - dimension = Counter(occurrences).most_common(1)[0][0] - should_factorize = lambda n: set(dimension).issubset(set(n.rank)) - elif mode == 'dimensions': - dimensions = kwargs.get('dimensions', ()) - should_factorize = lambda n: set(dimensions).issubset(set(n.rank)) - elif mode == 'adhoc': - adhoc = kwargs.get('adhoc') - if not adhoc: - return self - should_factorize = lambda n: n.urepr in adhoc - elif mode == 'heuristic': - kwargs['heuristic'] = True - should_factorize = lambda n: False - elif mode in ['all', 'linear', 'outlinear', 'constants']: - lda = kwargs.get('lda') or loops_analysis(self.expr_info.outermost_loop, - key='symbol', value='dim') - if mode == 'all': - should_factorize = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.dims for r in lda[n.symbol]) - elif mode == 'linear': - should_factorize = lambda n: lda.get(n.symbol) and \ - any(r in self.expr_info.linear_dims for r in lda[n.symbol]) - elif mode == 'outlinear': - should_factorize = lambda n: lda.get(n.symbol) and \ - not lda[n.symbol].issubset(set(self.expr_info.linear_dims)) - elif mode == 'constants': - should_factorize = lambda n: not lda.get(n.symbol) - else: - warn('Skipping unknown factorization strategy.') - return - - # Perform the factorization - self.factorizer.factorize(should_factorize, **kwargs) - return self - - def reassociate(self, reorder=None): - """Reorder symbols in associative operations following a convention. - By default, the convention is to order the symbols based on their rank. - For example, the terms in the expression :: - - a*b[i]*c[i][j]*d - - are reordered as :: - - a*d*b[i]*c[i][j] - - This as achieved by reorganizing the AST of the expression. - """ - - def _reassociate(node, parent): - if isinstance(node, (Symbol, Div)): - return - - elif isinstance(node, (Sum, Sub, FunCall, Ternary)): - for n in node.children: - _reassociate(n, node) - - elif isinstance(node, Prod): - children = explore_operator(node) - # Reassociate symbols - symbols = [n for n, p in children if isinstance(n, Symbol)] - # Capture the other children and recur on them - other_nodes = [(n, p) for n, p in children if not isinstance(n, Symbol)] - for n, p in other_nodes: - _reassociate(n, p) - # Create the reassociated product and modify the original AST - children = list(zip(*other_nodes))[0] if other_nodes else () - children += tuple(sorted(symbols, key=reorder)) - reassociated_node = ast_make_expr(Prod, children, balance=False) - parent.children[parent.children.index(node)] = reassociated_node - - else: - warn('Unexpected node %s while reassociating' % typ(node)) - - reorder = reorder if reorder else lambda n: (n.rank, n.dim) - _reassociate(self.stmt.rvalue, self.stmt) - return self - - def replacediv(self): - """Replace divisions by a constant with multiplications.""" - divisions = Find(Div).visit(self.stmt.rvalue)[Div] - to_replace = {} - for i in divisions: - if isinstance(i.right, Symbol): - if isinstance(i.right.symbol, (int, float)): - to_replace[i] = Prod(i.left, 1 / i.right.symbol) - elif isinstance(i.right.symbol, str) and i.right.symbol.isdigit(): - to_replace[i] = Prod(i.left, 1 / float(i.right.symbol)) - else: - to_replace[i] = Prod(i.left, Div(1.0, i.right)) - ast_replace(self.stmt, to_replace, copy=True, mode='symbol') - return self - - def preevaluate(self): - """Preevaluates subexpressions which values are compile-time constants. - In this process, reduction loops might be removed if the reduction itself - could be pre-evaluated.""" - # Aliases - stmt, expr_info = self.stmt, self.expr_info - - # Simplify reduction loops - if not isinstance(stmt, (Incr, Decr, IMul, IDiv)): - # Not a reduction expression, give up - return - expr_syms = Find(Symbol).visit(stmt.rvalue)[Symbol] - reduction_loops = expr_info.out_linear_loops_info - if any([not is_perfect_loop(l) for l, p in reduction_loops]): - # Unsafe if not a perfect loop nest - return - # The following check is because it is unsafe to simplify if non-loop or - # non-constant dimensions are present - hoisted_stmts = self.hoisted.all_stmts - hoisted_syms = [Find(Symbol).visit(h)[Symbol] for h in hoisted_stmts] - hoisted_dims = [s.rank for s in flatten(hoisted_syms)] - hoisted_dims = set([r for r in flatten(hoisted_dims) if not is_const_dim(r)]) - if any(d not in expr_info.dims for d in hoisted_dims): - # Non-loop dimension or non-constant dimension found, e.g. A[i], with /i/ - # not being a loop iteration variable - return - for i, (l, p) in enumerate(reduction_loops): - syms_dep = SymbolDependencies().visit(l, **SymbolDependencies.default_args) - if not all([(tuple(syms_dep[s]) == expr_info.loops and s.dim == len(expr_info.loops)) - for s in expr_syms if syms_dep[s]]): - # A sufficient (although not necessary) condition for loop reduction to - # be safe is that all symbols in the expression are either constants or - # tensors assuming a distinct value in each point of the iteration space. - # So if this condition fails, we give up - return - # At this point, tensors can be reduced along the reducible dimensions - reducible_syms = [s for s in expr_syms if not s.is_const] - # All involved symbols must result from hoisting - if not all([s.symbol in self.hoisted for s in reducible_syms]): - return - # Replace hoisted assignments with reductions - finder = Find(Assign, stop_when_found=True, with_parent=True) - for hoisted_loop in self.hoisted.all_loops: - for assign, parent in finder.visit(hoisted_loop)[Assign]: - sym, expr = assign.children - decl = self.hoisted[sym.symbol].decl - if sym.symbol in [s.symbol for s in reducible_syms]: - parent.children[parent.children.index(assign)] = Incr(sym, expr) - sym.rank = self.expr_info.linear_dims - decl.sym.rank = decl.sym.rank[i+1:] - # Remove the reduction loop - p.children[p.children.index(l)] = l.body[0] - # Update symbols' ranks - for s in reducible_syms: - s.rank = self.expr_info.linear_dims - # Update expression metadata - self.expr_info._loops_info.remove((l, p)) - - # Precompute constant expressions - decls = visit(self.header, info_items=['decls'])['decls'] - evaluator = Evaluate(decls, any(d.nonzero for s, d in decls.items())) - for hoisted_loop in self.hoisted.all_loops: - evals = evaluator.visit(hoisted_loop, **Evaluate.default_args) - # First, find out identical tables - mapper = defaultdict(list) - for s, values in evals.items(): - mapper[str(values)].append(s) - # Then, map identical tables to a single symbol - for values, symbols in mapper.items(): - to_replace = {s: symbols[0] for s in symbols[1:]} - ast_replace(self.stmt, to_replace, copy=True) - # Clean up - for s in symbols[1:]: - s_decl = self.hoisted[s.symbol].decl - self.header.children.remove(s_decl) - self.hoisted.pop(s.symbol) - evals.pop(s) - # Finally, update the hoisted symbols - for s, values in evals.items(): - hoisted = self.hoisted[s.symbol] - hoisted.decl.init = values - hoisted.decl.qual = ['static', 'const'] - self.hoisted.pop(s.symbol) - # Move all decls at the top of the kernel - self.header.children.remove(hoisted.decl) - self.header.children.insert(0, hoisted.decl) - self.header.children.insert(0, FlatBlock("// Preevaluated tables")) - # Clean up - self.header.children.remove(hoisted_loop) - return self - - def sharing_graph_rewrite(self): - """Rewrite the expression based on its sharing graph. Details in the - paper: - - An algorithm for the optimization of finite element integration loops - (Luporini et. al.) - """ - linear_dims = self.expr_info.linear_dims - other_dims = self.expr_info.out_linear_dims - - # Maximize visibility of linear symbols - self.expand(mode='all') - - # Make sure that potential reductions are not hidden away - lda = loops_analysis(self.header, value='dim') - self.reassociate(lambda i: (not lda[i]) + lda[i].issubset(set(other_dims))) - - # Construct the sharing graph - nodes, edges = [], [] - for i in summands(self.stmt.rvalue): - symbols = [i] if isinstance(i, Symbol) else list(zip(*explore_operator(i)))[0] - lsymbols = [s for s in symbols if any(d in lda[s] for d in linear_dims)] - lsymbols = [s.urepr for s in lsymbols] - nodes.extend([j for j in lsymbols if j not in nodes]) - edges.extend(combinations(lsymbols, r=2)) - sgraph = nx.Graph(edges) - - # Transform everything outside the sharing graph (pure linear, no ambiguity) - isolated = [n for n in nodes if n not in sgraph.nodes()] - for n in isolated: - self.factorize(mode='adhoc', adhoc={n: [] for n in nodes}) - self.licm('only_const').licm('only_outlinear') - - # Transform the expression based on the sharing graph - nodes = [n for n in nodes if n in sgraph.nodes()] - if not (nodes and all(sgraph.degree(n) > 0 for n in nodes)): - self.factorize(mode='heuristic') - self.licm('only_const').licm('only_outlinear') - return - # Use short variable names otherwise Pulp might complain - nodes_vars = {i: n for i, n in enumerate(nodes)} - vars_nodes = {n: i for i, n in nodes_vars.items()} - edges = [(vars_nodes[i], vars_nodes[j]) for i, j in edges] - - import pulp as ilp - - def setup(): - # ... declare variables - x = ilp.LpVariable.dicts('x', nodes_vars.keys(), 0, 1, ilp.LpBinary) - y = ilp.LpVariable.dicts('y', - [(i, j) for i, j in edges] + [(j, i) for i, j in edges], - 0, 1, ilp.LpBinary) - limits = defaultdict(int) - for i, j in edges: - limits[i] += 1 - limits[j] += 1 - - # ... define the problem - prob = ilp.LpProblem("Factorizer", ilp.LpMinimize) - - # ... define the constraints - for i in nodes_vars: - prob += ilp.lpSum(y[(i, j)] for j in nodes_vars if (i, j) in y) <= limits[i]*x[i] - - for i, j in edges: - prob += y[(i, j)] + y[(j, i)] == 1 - - # ... define the objective function (min number of factorizations) - prob += ilp.lpSum(x[i] for i in nodes_vars) - - return x, prob - - # Solve the ILP problem to find out the minimal-cost factorization strategy - x, prob = setup() - prob.solve(ilp.GLPK(msg=0)) - - # Also attempt to find another optimal factorization, but with - # additional constraints on the reduction dimensions. This may help in - # later rewrite steps - if len(other_dims) > 1: - z, prob = setup() - for i, n in nodes_vars.items(): - if not set(n[1]).intersection(set(other_dims[:-1])): - prob += z[i] == 0 - prob.solve(ilp.GLPK(msg=0)) - if ilp.LpStatus[prob.status] == 'Optimal': - x = z - - # ... finally, apply the transformations. Observe that: - # 1) the order (first /nodes/, than /other_nodes/) in which - # the factorizations are carried out is crucial - # 2) sorting /nodes/ and /other_nodes/ locally ensures guarantees - # deterministic output code - # 3) precedence is given to outer reduction loops; this maximises the - # impact of later transformations, while not affecting this pass - # 4) with_promotion is set to true if there exist potential reductions - # to simplify - nodes = [nodes_vars[n] for n, v in x.items() if v.value() == 1] - other_nodes = [nodes_vars[n] for n, v in x.items() if nodes_vars[n] not in nodes] - for n in sorted(nodes, key=itemgetter(1)) + sorted(other_nodes): - self.factorize(mode='adhoc', adhoc={n: []}) - self.licm('incremental', with_promotion=len(other_dims) > 1) - - return self diff --git a/coffee/scheduler.py b/coffee/scheduler.py deleted file mode 100644 index 3da18964..00000000 --- a/coffee/scheduler.py +++ /dev/null @@ -1,856 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six import iteritems -from six.moves import range, zip - -from collections import OrderedDict, defaultdict -from itertools import product -from copy import deepcopy as dcopy - -from .base import * -from .utils import * -from .expression import copy_metaexpr -from .rewriter import ExpressionRewriter -from .exceptions import ControlFlowError, UnexpectedNode -from coffee.visitors import FindLoopNests - - -class LoopScheduler(object): - - """Base class for classes that handle loop scheduling; that is, loop fusion, - loop distribution, etc.""" - - -class SSALoopMerger(LoopScheduler): - - """Analyze data dependencies and iteration spaces, then merge fusible loops. - Statements must be in "soft" SSA form: they can be declared and initialized - at declaration time, then they can be assigned a value in only one place.""" - - def _merge_loops(self, root, loop_a, loop_b): - """Merge the body of ``loop_a`` into ``loop_b``.""" - root.children.remove(loop_a) - - dims_a, dims_b = [loop_a.dim], [loop_b.dim] - while isinstance(loop_b.body[0], For): - dims_b.append(loop_b.dim) - loop_b = loop_b.body[0] - while isinstance(loop_a.body[0], For): - dims_a.append(loop_a.dim) - loop_a = loop_a.body[0] - - loop_b.body = loop_a.body + loop_b.body - - ast_update_rank(loop_b, dict(zip(dims_a, dims_b))) - - def _simplify(self, merged_loops): - """Scan the list of merged loops and eliminate sub-expressions that became - duplicate as now iterating along the same iteration space. For example: :: - - for i = 0 to N - A[i] = B[i] + C[i] - for j = 0 to N - D[j] = B[j] + C[j] - - After merging this becomes: :: - - for i = 0 to N - A[i] = B[i] + C[i] - D[i] = B[i] + C[i] - - And finally, after simplification (i.e. after ``simplify`` is applied): :: - - for i = 0 to N - A[i] = B[i] + C[i] - D[i] = A[i] - """ - for loop in merged_loops: - to_replace = {} - for stmt in loop.body: - ast_replace(stmt, to_replace, copy=True) - if not isinstance(stmt, AugmentedAssign): - to_replace[stmt.rvalue] = stmt.lvalue - - def merge(self, root): - """Merge perfect loop nests in ``root``.""" - - # Make sure there are no empty loops within root, otherwise kill them - remove_empty_loops(root) - - expr_graph = ExpressionGraph(root) - - # Collect iteration spaces visiting the tree rooted in /root/ - found_nests = OrderedDict() - for n in root.children: - if isinstance(n, For): - retval = FindLoopNests.default_retval() - loops_infos = FindLoopNests().visit(n, parent=root, ret=retval) - for li in loops_infos: - loops, loops_parents = zip(*li) - # Note that only inner loops can be fused, and that they must - # share the same parent node - key = (tuple(l.header for l in loops), loops_parents[-1]) - found_nests.setdefault(key, []).append(loops[-1]) - - all_merged, merged_loops = [], [] - # A perfect loop nest L1 is mergeable in a loop nest L2 if - # 1 - their iteration space is identical; implicitly true because the keys, - # in the dictionary, are iteration spaces. - # 2 - between the two nests, there are no statements that read/write values - # computed in L1. This is checked later - # 3 - there are no read-after-write dependencies between variables written - # in L1 and read in L2. This is checked later - # In the following, convention is that L2 = /merging_in/, L1 = /l/ - for (itspace, parent), loop_nests in found_nests.items(): - if len(loop_nests) == 1: - # At least two loops are necessary for merging to be meaningful - continue - mergeable = [] - merging_in = loop_nests[-1] - retval = SymbolModes.default_retval() - merging_in_reads = SymbolModes().visit(merging_in.body, ret=retval) - merging_in_reads = [s for s, m in merging_in_reads.items() if m[0] == READ] - for l in loop_nests[:-1]: - is_mergeable = True - # Get the symbols written in /l/ - l_writes = SymbolModes().visit(l.body, ret=SymbolModes.default_retval()) - l_writes = [s for s, m in l_writes.items() if m[0] == WRITE] - - # Check condition 2 - # Get the symbols written between loop /l/ (excluded) and loop - # merging_in (excluded) - bound_left = parent.children.index(l)+1 - bound_right = parent.children.index(merging_in) - for n in parent.children[bound_left:bound_right]: - in_writes = SymbolModes().visit(n, ret=SymbolModes.default_retval()) - in_writes = [s for s, m in in_writes.items()] - for iw, lw in product(in_writes, l_writes): - if expr_graph.is_written(iw, lw): - is_mergeable = False - break - - # Check condition 3 - for lw, mir in product(l_writes, merging_in_reads): - if lw.symbol == mir.symbol and not lw.rank and not mir.rank: - is_mergeable = False - break - - # Track mergeable loops - if is_mergeable: - mergeable.append(l) - - # If there is at least one mergeable loops, do the merging - for l in reversed(mergeable): - self._merge_loops(parent, l, merging_in) - # Update the lists of merged loops - all_merged.append((mergeable, merging_in)) - merged_loops.append(merging_in) - - # Reuse temporaries in merged loops - self._simplify(merged_loops) - - return all_merged - - -class ExpressionFissioner(LoopScheduler): - - """Split expressions embedded in a loop nest.""" - - def __init__(self, **kwargs): - """Initialize the ExpressionFissioner. - - :arg kwargs: - * cut: the number of operands an expression should be fissioned into - * match: a list of subexpressions that should be cut from the input - expression. ``cut`` is ignored if ``match`` is provided. - * loops: a value in ['all', 'expr', 'none']. 'all' means that an - expression is split and its "chunks" are placed in separate loop - nests. 'expr' implies that the chunks are placed within the non - linear loops sorrounding the expression. 'none' means that all - chunks are simply placed within the orginal loop nest - * perfect: if True, create perfect loop nests. This means that any - new loop nest in which a chunk is placed is purged from any extra - statement (apart, obviously, from the chunk itself) - """ - self.cut = kwargs.get('cut', -1) - self.match = [str(i) for i in kwargs.get('match', [])] - self.loops = kwargs.get('loops', 'expr') - self.perfect = kwargs.get('perfect', False) - - if 'match' in kwargs: - self.cutter = self.CutterMatch(self) - elif self.cut > 0: - self.cutter = self.CutterSum(self) - else: - raise RuntimeError("Must specify a `cut` or a `match`.") - - class Cutter(object): - - def __init__(self, expr_fissioner): - self.expr_fissioner = expr_fissioner - - def cut(self, node): - """ - Split ``node`` into /two halves/, called /split/ and /remainder/ - - For example, consider the expression a*b + c*d; if the expression is cut - into chunks containing only one operand (i.e., self.cut=1), then we have - precisely two chunks, /split/ = a*b, /remainder/ = c*d - - If the input expression is a*b + c*d + e*f, and still self.cut=1, then we - have two chunks, /split/ = a*b, /remainder/ = c*d + e*f; that is, - /remainder/ always contains the subexpression after the fission point - """ - self._success = False - left = dcopy(node) - self._cut(left.children[1], left, 'split') - - self._success = False - right = dcopy(node) - self._cut(right.children[1], right, 'remainder') - - return left, right - - class CutterSum(Cutter): - - def _cut(self, node, parent, side, topsum=None): - if isinstance(node, (Symbol, FunCall, Ternary)): - return 0 - - elif isinstance(node, Div): - return self._cut(node.children[0], node, side, topsum) - - elif isinstance(node, Prod): - if topsum: - return 0 - if self._cut(node.left, node, side, topsum) == 0: - return self._cut(node.right, node, side, topsum) - # Prods zero the sum/sub counter - return 0 - - elif isinstance(node, (Sum, Sub)): - topsum = topsum or (parent, parent.children.index(node)) - counter = 1 - counter += self._cut(node.left, node, side, topsum) - counter += self._cut(node.right, node, side, topsum) - if not self._success and counter >= self.expr_fissioner.cut: - # We now are on the topleft sum of this sub-expression such - # that enough sum/sub have been encountered - if not parent: - return 0 - self._success = True - if side == 'split': - topsum[0].children[topsum[1]] = node.left - else: - right = Neg(node.right) if isinstance(node, Sub) else node.right - parent.children[parent.children.index(node)] = right - return counter - else: - return counter - - else: - raise UnexpectedNode("Fission: %s" % str(node)) - - def cut(self, node, expr_info): - left, right = ExpressionFissioner.Cutter.cut(self, node) - if self._success: - index = expr_info.parent.children.index(node) - - # Append /left/ to the original loop nest - expr_info.parent.children[index] = left - split = (left, copy_metaexpr(expr_info)) - - # Append /right/ ... - if self.expr_fissioner.loops in ['expr', 'all']: - # ... in a new loop nest ... - right_info = self.expr_fissioner._embedexpr(right, expr_info) - else: - # ... to the original loop nest - expr_info.parent.children.insert(index, right) - right_info = copy_metaexpr(expr_info) - splittable = (right, right_info) - - return (split, splittable) - return ((node, expr_info), ()) - - class CutterMatch(Cutter): - - def __init__(self, expr_fissioner): - ExpressionFissioner.Cutter.__init__(self, expr_fissioner) - self.matched = [] - - def _cut(self, node, parent, side, topsum=None): - if not self._success and str(node) in self.expr_fissioner.match: - # We initially assume that the found 'match' corresponds - # to the entire node provided as input to the /CutterMatch/. - # Recurring back, we might switch /_success/ to 'match_and_cut', - # if /node/ actually was a summand of a Sum/Sub - self._success = 'match' - return node - - elif isinstance(node, (Symbol, FunCall)): - return None - - elif isinstance(node, Div): - return self._cut(node.left, node, side) - - elif isinstance(node, Prod): - cutting = self._cut(node.left, node, side) - if cutting: - # Found a match /within/ /node.left/; for correctness, we - # need to be sure we will be cutting the whole Prod, so we - # return /node/ instead of /cutting/. - return node - cutting = self._cut(node.right, node, side) - if cutting: - # Same as above - return node - return None - - elif isinstance(node, (Sum, Sub)): - topsum = topsum or (parent, parent.children.index(node)) - # Find out if one of the two children is cuttable - cutting = self._cut(node.left, node, side, topsum) - if cutting and side == 'remainder': - # Need to swap - cutting = node.right - elif not cutting: - cutting = self._cut(node.right, node, side, topsum) - if cutting and side == 'remainder': - # Need to swap - cutting = node.left - if not cutting: - return None - # Adjust if a Sub - if isinstance(node, Sub) and cutting == node.right: - cutting = Neg(cutting) - self._success = 'match_and_cut' - if side == 'split': - # In a tree of Sum/Subs, only the /top/ Sum/Sub performs the - # actual cut, while the others just propagate upwards the - # notification "a cut point was found" - if parent == topsum[0]: - topsum[0].children[topsum[1]] = cutting - return parent - else: - return cutting - else: - parent.children[parent.children.index(node)] = cutting - return None - - else: - raise UnexpectedNode("Fission: %s" % str(node)) - - def cut(self, node, expr_info): - left, right = ExpressionFissioner.Cutter.cut(self, node) - - if self._success == 'match_and_cut': - # Append /left/ to a new loop nest - split = (left, self.expr_fissioner._embedexpr(left, expr_info)) - self.matched.append(left) - - # Append /right/ to the original loop nest - index = expr_info.parent.children.index(node) - expr_info.parent.children[index] = right - splittable = (right, copy_metaexpr(expr_info)) - return (split, splittable) - - elif self._success == 'match': - # A match was actualy found, but there's just nothing to cut - # (i.e., the /match/ is a direct child of /node/) - self.matched.append(node) - - return ((node, expr_info), ()) - - def _embedexpr(self, stmt, expr_info): - """Build a loop nest for ``stmt`` and return its :class:`MetaExpr` object.""" - if self.loops == 'none': - return copy_metaexpr(expr_info) - - # Handle the linear loops - linear_loops = ItSpace(mode=2).to_for(expr_info.linear_loops, stmts=[stmt]) - linear_outerloop = linear_loops[0] - - # Handle the out-linear loops - if self.loops == 'all' and expr_info.out_linear_loops_info: - out_linear_loop, out_linear_loop_parent = expr_info.out_linear_loops_info[0] - index = out_linear_loop.body.index(expr_info.linear_loops[0]) - out_linear_loop = dcopy(out_linear_loop) - if self.perfect: - out_linear_loop.body[:] = [linear_outerloop] - else: - out_linear_loop.body[index] = linear_outerloop - out_linear_loops_info = ((out_linear_loop, out_linear_loop_parent),) - linear_outerloop_parent = out_linear_loop.children[0] - else: - out_linear_loops_info = expr_info.out_linear_loops_info - linear_outerloop_parent = expr_info.linear_loops_parents[0] - - # Build new loops info - finder, env = FindLoopNests(), {'node_parent': linear_outerloop_parent} - loops_info = out_linear_loops_info - loops_info += tuple(finder.visit(linear_outerloop, env=env)[0]) - - # Append the newly created loop nest - if self.loops == 'all' and expr_info.out_linear_loops_info: - expr_info.outermost_parent.children.append(out_linear_loop) - else: - linear_outerloop_parent.children.append(linear_outerloop) - - # Finally, create and return the MetaExpr object - parent = loops_info[-1][0].children[0] - return copy_metaexpr(expr_info, parent=parent, loops_info=loops_info) - - @property - def matched(self): - return self.cutter.matched if self.match else [] - - def fission(self, stmt, expr_info): - """Split, or fission, an expression ``stmt``, whose metadata are provided - through ``expr_info``. - - Return a dictionary mapping expression chunks to :class:`MetaExpr` objects. - - :arg stmt: the expression to be fissioned - :arg expr_info: ``MetaExpr`` object describing ``stmt`` - """ - exprs = OrderedDict() - splittable = (stmt, expr_info) - while splittable: - split, splittable = self.cutter.cut(*splittable) - exprs[split[0]] = split[1] - return exprs - - -class ZeroRemover(LoopScheduler): - - """Analyze data dependencies and iteration spaces to remove arithmetic - operations in loops that iterate over zero-valued blocks. Consequently, - loop nests can be fissioned and/or merged. For example: :: - - for i = 0, N - A[i] = C[i]*D[i] - B[i] = E[i]*F[i] - - If the evaluation of A requires iterating over a block of zero (0.0) values, - because for instance C and D are block-sparse, then A is evaluated in a - different, smaller (i.e., with less iterations) loop nest: :: - - for i = 0 < (N-k) - A[i+k] = C[i+k][i+k] - for i = 0, N - B[i] = E[i]*F[i] - - The implementation is based on symbolic execution. Control flow is not - admitted. - """ - - THRESHOLD = 1 # Only skip if there more than THRESHOLD consecutive zeros - - def __init__(self, exprs, hoisted): - """Initialize the ZeroRemover. - - :param exprs: the expressions for which zero removal is performed. - :param hoisted: dictionary that tracks hoisted sub-expressions - """ - self.exprs = exprs - self.hoisted = hoisted - - def _track_nz_expr(self, node, nz_syms, nest): - """For the expression rooted in ``node``, return iteration space and - offset required to iterate over non zero-valued blocks. For example: :: - - for i = 0 to N - for j = 0 to N - A[i][j] = B[i]*C[j] - - If B along `i` is non-zero in ranges [0, k1] and [k2, k3], while C along - `j` is non-zero in range [N-k4, N], return the intersection of the non-zero - regions as: :: - - [(('i', k1, 0), ('j', N-(N-k4), N-k4))), - (('i', k3-k2, k2), ('j', N-(N-k4), N-k4))] - - That is, for each iteration space variable, return a list of 2-tuples, - in which the first entry represents the size of the iteration space, - and the second entry represents the offset in memory to access the - correct values. - """ - - if isinstance(node, Symbol): - itspace = [] - def_itspace = [tuple((l.dim, Region(l.size, 0)) for l, p in nest)] - nz_bounds = zip(*nz_syms.get(node.symbol, [])) - for i, (r, o, nz_bs) in enumerate(zip(node.rank, node.offset, nz_bounds)): - if o[0] != 1 or isinstance(o[1], str) or is_const_dim(r): - # Cannot handle jumps, non-integer offsets, or constant accesses - continue - try: - # Am I tracking the loop with iteration variable == /r/ ? - loop = [l for l, p in nest if l.dim == r][0] - except IndexError: - # No, so I just assume it covers the entire non zero-valued region - itspace.append([(r, nz_b) for nz_b in nz_bs]) - continue - # Now I can intersect the loop's iteration space with the non - # zero-valued regions - offset = o[1] - r_region = [] - for nz_b in nz_bs: - nz_b_size, nz_b_offset = nz_b - end = nz_b_size + nz_b_offset - start = max(offset, nz_b_offset) - r_offset = start - offset - r_size = max(min(offset + loop.size, end) - start, 0) - r_region.append((r, Region(r_size, r_offset))) - itspace.append(r_region) - itspace = list(zip(*itspace)) or def_itspace - return itspace - - elif isinstance(node, FunCall): - return self._track_nz_expr(node.children[0], nz_syms, nest) - - elif isinstance(node, Ternary): - raise ControlFlowError - - else: - itspace_l = self._track_nz_expr(node.left, nz_syms, nest) - itspace_r = self._track_nz_expr(node.right, nz_syms, nest) - itspace = OrderedDict() - for l in itspace_l: - for i, region in l: - itspace.setdefault(i, []).append(region) - asdict = OrderedDict() - for r in itspace_r: - for i, region in r: - asdict.setdefault(i, []).append(region) - itspace_r = asdict - for i, region in itspace_r.items(): - if i not in itspace: - itspace[i] = region - elif isinstance(node, (Prod, Div)): - result = [] - for j in product(itspace[i], region): - # Products over zero-valued regions are ininfluent - result += [ItSpace(mode=1).intersect(j)] - itspace[i] = result - elif isinstance(node, (Sum, Sub)): - # Sums over zeros remove the zero-valued region (in other words, - # the non zero-valued regions get /merged/) - itspace[i] = ItSpace(mode=1).merge(itspace[i] + region) - else: - raise UnexpectedNode("Zero-avoidance: %s", str(node)) - itspace = list(set(tuple(zip(itspace, i)) - for i in product(*itspace.values()))) - return itspace - - def _track_nz_blocks(self, node, nz_syms, nz_info, nest=None, parent=None, candidates=None): - """Track the propagation of zero-valued blocks in the AST rooted in ``node`` - - ``nz_syms`` contains, for each known identifier, the ranges of - its non zero-valued blocks. For example, assuming identifier A is an - array and has non-zero values in positions [0, k] and [N-k, N], then - ``nz_syms`` will contain an entry {"A": ((0, k), (N-k, N))}. - If A is modified by some statements rooted in ``node``, then - ``nz_syms["A"]`` will be modified accordingly. - - This method also populates ``nz_info``, which maps loop nests to the - enclosed symbols' non-zero blocks. For example, given the following - code: :: - - { // root - ... - for i - for j - A = ... - B = ... - } - - After the traversal of the AST, the ``nz_info`` dictionary will look like: :: - - ((, ), root) -> {A: (i, (nz_along_i)), (j, (nz_along_j))} - - """ - if isinstance(node, Writer): - sym, expr = node.children - - # Outer, non-perfect loops are discarded for transformation safety - # as splitting (a consequence of zero-removal) non-perfect nests is unsafe - nest = tuple([(l, p) for l, p in (nest or []) if is_perfect_loop(l)]) - if not nest: - return - - if nest[-1][0] not in candidates: - return - - # Track the propagation of non zero-valued blocks: ... - # ... within the rvalue - itspaces = self._track_nz_expr(expr, nz_syms, nest) - for i in itspaces: - # ... and then through the lvalue (merging overlaps) - nz_expr = tuple(dict(i).get(r) for r in sym.rank if not is_const_dim(r)) - if any(j is None for j in nz_expr): - break - nz_node = list(nz_syms.setdefault(sym.symbol, [nz_expr])) - if not nz_expr: - continue - merged = False - for e, j in enumerate(nz_node): - # Merging condition: complete overlap in all dimensions but - # the innermost one, for which partial overlap is accepted - inner_merge = ItSpace(mode=1).merge([nz_expr[-1], j[-1]]) - if len(inner_merge) == 1 and \ - all(ItSpace(mode=1).intersect([m, n]) == m for m, n in - zip(nz_expr[:-1], j[:-1])): - nz_syms[sym.symbol][e] = j[:-1] + tuple(inner_merge) - merged = True - break - if not merged: - nz_syms[sym.symbol].append(nz_expr) - - # Record loop nest bounds and memory offsets for /node/ - dims = [l.dim for l, p in nest] - itspaces = [tuple(j for j in i if j[0] in dims) for i in itspaces] - nz_info.setdefault(nest, []).append((node, itspaces)) - - elif isinstance(node, For): - new_nest = (nest or []) + [(node, parent)] - self._track_nz_blocks(node.children[0], nz_syms, nz_info, new_nest, - node, candidates) - - elif isinstance(node, (Root, Block)): - for n in node.children: - self._track_nz_blocks(n, nz_syms, nz_info, nest, node, candidates) - - else: - raise ControlFlowError - - def _reschedule_itspace(self, root, candidates, decls): - """Consider two statements A and B, and their iteration space. If the - two iteration spaces have - - * Same size and same bounds, then put A and B in the same loop nest: :: - - for i, for j - W1[i][j] = W2[i][j] - Z1[i][j] = Z2[i][j] - - * Same size but different bounds, then put A and B in the same loop - nest, but add suitable offsets to all of the involved iteration - variables: :: - - for i, for j - W1[i][j] = W2[i][j] - Z1[i+k][j+k] = Z2[i+k][j+k] - - * Different size, then put A and B in two different loop nests: :: - - for i, for j - W1[i][j] = W2[i][j] - for i, for j // Different loop bounds - Z1[i][j] = Z2[i][j] - - A dictionary describing the structure of the new iteration spaces is - returned. - """ - nz_info = OrderedDict() - - # Compute the initial sparsity pattern of the symbols in /root/ - nz_syms = defaultdict(list) - for s, d in decls.items(): - if not d.nonzero: - continue - for nz_b in product(*d.nonzero): - entries = [list(range(i.ofs, i.ofs + i.size)) for i in nz_b] - if not np.all(d.init.values[np.ix_(*entries)] == 0.0): - nz_syms[s].append(nz_b) - - # Track the propagation of non zero-valued blocks through symbolic - # execution. This populates /nz_info/ and updates /nz_syms/ - try: - self._track_nz_blocks(root, nz_syms, nz_info, candidates=candidates) - except ControlFlowError: - # Couldn't perform symbolic execution due to runtime-dependent data - return nz_syms, OrderedDict() - - # At this point we know where non-zero blocks are located, so we have - # to create proper loop nests to access them - new_exprs, new_nz_info = OrderedDict(), OrderedDict() - for nest, stmt_itspaces in nz_info.items(): - loops, loops_parents = zip(*nest) - fissioned_nests = defaultdict(list) - # Fission the nest to get rid of computation over zero-valued blocks - for stmt, itspaces in stmt_itspaces: - sym, expr = stmt.children - # For each non zero-valued region iterated over... - for i in itspaces: - dim_offset = {d: o for d, (sz, o) in i} - dim_size = tuple(((0, dict(i)[l.dim][0]), l.dim) for l in loops) - # ...add an offset to /stmt/ to access the correct values - new_stmt = ast_update_ofs(dcopy(stmt), dim_offset, increase=True) - # ...add /stmt/ to a new, shorter loop nest - fissioned_nests[dim_size].append((new_stmt, dim_offset)) - # ...initialize arrays to 0.0 for correctness - if sym.symbol in self.hoisted: - self.hoisted[sym.symbol].decl.init = ArrayInit(np.array([0.0])) - # ...track fissioned expressions - if stmt in self.exprs: - new_exprs[new_stmt] = self.exprs[stmt] - # Generate the fissioned loop nests - # Note: the dictionary is sorted because smaller loop nests should - # be executed first, since larger ones depend on them - for dim_size, stmt_dim_offsets in sorted(fissioned_nests.items()): - if all([sz == (0, 0) for sz, dim in dim_size]): - # Discard empty loop nests - continue - # Create the new loop nest ... - new_loops = ItSpace(mode=0).to_for(*zip(*dim_size)) - for stmt, _ in stmt_dim_offsets: - # ... populate it - new_loops[-1].body.append(stmt) - # ... and update tracked data - if stmt in new_exprs: - new_nest = list(zip(new_loops, loops_parents)) - new_exprs[stmt] = copy_metaexpr(new_exprs[stmt], - parent=new_loops[-1].body, - loops_info=new_nest) - self.hoisted.update_stmt(stmt.children[0].symbol, - loop=new_loops[0], - place=loops_parents[0]) - new_nz_info[tuple(new_loops)] = stmt_dim_offsets - # Append the new loops to the root - insert_at_elem(loops_parents[0].children, loops[0], new_loops[0]) - loops_parents[0].children.remove(loops[0]) - - self.exprs.clear() - self.exprs.update(new_exprs) - return nz_syms, new_nz_info - - def _recombine(self, nz_info): - """Recombine expressions writing to the same lvalue.""" - new_exprs = OrderedDict() - ops = {Incr: Sum, Decr: Sub, IMul: Prod} - - for nest, stmt_dim_offsets in nz_info.items(): - mapper = OrderedDict() - for stmt, dim_offsets in stmt_dim_offsets: - sym, expr = stmt.children - if type(stmt) in ops: - # The /key/ means: I'm in the same loop nest, I'm writing to - # the same symbol, and in particular to the same symbol - # locations, and I'm doing an associative AugmentedAssignment. - key = (str(sym), type(stmt)) - mapper.setdefault(key, []).append(stmt) - - for (_, op), stmts in mapper.items(): - exprs = [i.children[1] for i in stmts] - for i in stmts: - nest[-1].body.remove(i) - stmt = op(i.children[0], ast_make_expr(ops[op], exprs)) - nest[-1].body.append(stmt) - # Update the tracked expressions, if necessary - if all(i in self.exprs for i in stmts): - new_exprs[stmt] = self.exprs[i] - - for stmt, expr_info in new_exprs.items(): - ew = ExpressionRewriter(stmt, expr_info) - ew.factorize('heuristic') - - if new_exprs: - self.exprs.clear() - self.exprs.update(new_exprs) - - def _should_skip(self, zero_decls): - """Return False if, based on heuristics, it seems worth skipping the - computation over zeros, True otherwise. True is returned if it - is thought that the implications on low-level performance would be - worse than the gain in operation count (e.g., because spatial locality - within loop would go lost).""" - - if not zero_decls: - return True - - for d in zero_decls: - for d_dim in d.nonzero: - if all(size < ZeroRemover.THRESHOLD for size, offset in d_dim): - return True - - return False - - def reschedule(self, root): - """Restructure the loop nests in ``root`` to avoid computation over - zero-valued data spaces. This is achieved through symbolic execution - starting from ``root``. Control flow, in the form of If, Switch, etc., - is forbidden.""" - decls = visit(root, info_items=['decls'])['decls'] - - # Avoid rescheduling if zero-valued blocks are too small - zero_decls = [d for d in decls.values() if d.nonzero] - if self._should_skip(zero_decls): - return {} - - # Determine the analyzable loops (inner loops in which statements have no - # read-after-write dependencies) - linear_expr_loops = [(l for l in ei.linear_loops) for ei in self.exprs.values()] - linear_expr_loops = set(flatten(linear_expr_loops)) - candidates = [l for l in inner_loops(root) - if not l.is_linear or l in linear_expr_loops] - candidates = [l for l in candidates - if not ExpressionGraph(l.body).has_dependency()] - if not candidates: - return {} - - if linear_expr_loops & set(candidates): - # Split the main expressions to maximize the impact of the rescheduling (this - # helps if different summands have zero-valued blocks at different offsets) - elf = ExpressionFissioner(cut=1, loops='none') - new_exprs = {} - for stmt, expr_info in iteritems(self.exprs): - if expr_info.is_scalar: - new_exprs[stmt] = expr_info - else: - new_exprs.update(elf.fission(stmt, expr_info)) - self.exprs = new_exprs - - # Apply the rescheduling - nz_syms, nz_info = self._reschedule_itspace(root, candidates, decls) - - # Finally, "inline" the expressions that were originally split, if possible - self._recombine(nz_info) - else: - # Apply the rescheduling - nz_syms, nz_info = self._reschedule_itspace(root, candidates, decls) - - return nz_syms diff --git a/coffee/version.py b/coffee/version.py deleted file mode 100644 index 1ca68113..00000000 --- a/coffee/version.py +++ /dev/null @@ -1,4 +0,0 @@ -from __future__ import absolute_import, print_function, division - -__version_info__ = (0, 1, 0) -__version__ = '.'.join(map(str, __version_info__)) From 98d930180088ec13595510a50d0534edfc4aa7be Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 16 Apr 2019 11:56:08 +0100 Subject: [PATCH 3/3] Strip more unused code --- coffee/__init__.py | 154 ------ coffee/exceptions.py | 44 -- coffee/expression.py | 228 --------- coffee/optimizer.py | 69 --- coffee/plan.py | 233 --------- coffee/system.py | 145 ------ coffee/utils.py | 923 ---------------------------------- coffee/vectorizer.py | 740 --------------------------- coffee/visitors/__init__.py | 2 - coffee/visitors/inspectors.py | 685 ------------------------- coffee/visitors/utilities.py | 330 +----------- tests/test_visitors.py | 455 ----------------- 12 files changed, 1 insertion(+), 4007 deletions(-) delete mode 100644 coffee/exceptions.py delete mode 100644 coffee/expression.py delete mode 100644 coffee/optimizer.py delete mode 100644 coffee/plan.py delete mode 100644 coffee/system.py delete mode 100644 coffee/utils.py delete mode 100644 coffee/vectorizer.py delete mode 100644 coffee/visitors/inspectors.py delete mode 100644 tests/test_visitors.py diff --git a/coffee/__init__.py b/coffee/__init__.py index f6a42734..f63eae7d 100644 --- a/coffee/__init__.py +++ b/coffee/__init__.py @@ -30,157 +30,3 @@ # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division - -import sys - -from coffee.citations import update_citations -from coffee.logger import LOG_DEFAULT, set_log_level, warn -from coffee.system import architecture, compiler, isa -from coffee.system import set_architecture, set_compiler, set_isa - - -__all__ = ['options', 'initialized', 'O0', 'O1', 'O2', 'O3', 'Ov'] - - -class Options(dict): - - def __init__(self): - self._callbacks = {} - - def __setitem__(self, key, value): - super(Options, self).__setitem__(key, value) - self.maybe_update_backend(key, value) - - def register(self, key, value=None, callback=None): - self[key] = value - if callback: - self._callbacks[key] = callback - - def maybe_update_backend(self, key, value): - if key in self._callbacks: - self._callbacks[key](value) - - -class OptimizationLevel(dict): - - _KNOWN = {} - - @classmethod - def retrieve(cls, optlevel): - """Retrieve the set of transformations corresponding to ``optlevel``. - - :param optlevel: may be an :class:`OptimizationLevel` itself (in which - case ``optlevel`` itself is returned) or the name of the level (a string). - """ - if isinstance(optlevel, OptimizationLevel): - return optlevel - elif isinstance(optlevel, str) and optlevel in cls._KNOWN: - return cls._KNOWN[optlevel] - elif not optlevel: - return O0 - else: - warn("Unrecognized optimization specified.") - return O0 - - def __init__(self, name, **kwargs): - self.name = name - - for key, value in kwargs.items(): - self[key] = value - - OptimizationLevel._KNOWN[name] = self - - -def coffee_init(**kwargs): - """Initialize COFFEE. - - :param compiler: Options: ``gnu``, ``intel``. By knowing the backend compiler, - COFFEE can generate specialized code (e.g., by inserting suitable loop pragmas). - :param isa: Options: ``sse``, ``avx``. The instruction set architecture tells - COFFEE the vector length and the available intrinsics so that optimized - vector code (or scalar code suitable for compiler auto-vectorization) is - generated. - :param architecture: Options: ``default``, ``intel``. - :param optlevel: Options: ``O0`` (default), ``O1``, ``O2``, ``O3``, ``Ofast``. - The higher the optimization level, the more aggresively are the transformations. - For more details, refer to the ``set_opt_level``'s documentation. - """ - - global initialized, options - - architecture_id = kwargs.get('architecture', 'default') - compiler_id = kwargs.get('compiler') - isa_id = kwargs.get('isa') - optlevel = kwargs.get('optlevel', O0) - - architecture.clear() - architecture.update(set_architecture(architecture_id)) - - compiler.clear() - compiler.update(set_compiler(compiler_id)) - - isa.clear() - isa.update(set_isa(isa_id)) - - if all([architecture, compiler, isa]): - initialized = True - - options['architecture'] = architecture_id - options['compiler'] = compiler_id - options['isa'] = isa_id - options['optimizations'] = optlevel - - # Allow visits of ASTs that become huge due to transformation. The constant - # /4000/ was empirically found to be an acceptable value - sys.setrecursionlimit(4000) - - -def coffee_reconfigure(**kwargs): - """Reconfigure the internal state of COFFEE.""" - - options['optimizations'] = kwargs.get('optlevel') - - -def set_opt_level(optlevel): - """Set the default optimization level. - - :param optlevel: accepted values are: :: - - ``O0``: No optimizations are applied at all (default). - ``O1``: Apply generalized loop-invariant code motion. Refer to - ``citations.LUPORINI2015`` for more information. - ``O2``: Apply sharing elimination and elimination of useless floating - point operations (e.g., a + 0 == a). Refer to ``citations.LUPORINI2016`` - for more information. - ``O3``: Apply ``O2`` and enforce data alignment through array padding. - This maximizes the impact of compiler auto-vectorization, as thoroughly - discussed in ``citations.LUPORINI2015``. - ``Ofast``: Apply ``O3``, but resort to explicit outer-product vectorization - instead. Vector promotion is also attempted to maximize vectorization in - the outer loops. Refer to ``citations.LUPORINI2015`` for more information. - - Alternatively, one can craft a new composite transformation by creating a - new :class:`OptimizationLevel`. - """ - - optimizations = OptimizationLevel.retrieve(optlevel) - - update_citations(optimizations) - - -O0 = OptimizationLevel('O0') -O1 = OptimizationLevel('O1', rewrite=1) -O2 = OptimizationLevel('O2', rewrite=2, dead_ops_elimination=True) -O3 = OptimizationLevel('O3', align_pad=True, **O2) -Ov = OptimizationLevel('Ov', align_pad=True) - -initialized = False - -options = Options() -options.register('logging', LOG_DEFAULT, set_log_level) -options.register('architecture', callback=set_architecture) -options.register('compiler', callback=set_compiler) -options.register('isa', callback=set_isa) -options.register('optimizations', O0.name, callback=set_opt_level) diff --git a/coffee/exceptions.py b/coffee/exceptions.py deleted file mode 100644 index 434f2ab6..00000000 --- a/coffee/exceptions.py +++ /dev/null @@ -1,44 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division - - -class ControlFlowError(Exception): - """The control flow prevents an AST transformation.""" - pass - - -class UnexpectedNode(Exception): - """A particular node prevents an AST transformation.""" - pass diff --git a/coffee/expression.py b/coffee/expression.py deleted file mode 100644 index 5b09c685..00000000 --- a/coffee/expression.py +++ /dev/null @@ -1,228 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six.moves import zip - -from .utils import * -from collections import OrderedDict - - -class MetaExpr(object): - - """Metadata container for a compute-intensive expression.""" - - def __init__(self, type, parent, loops_info, mode=0): - """Initialize the MetaExpr. - - :param type: the C type of the expression. - :param parent: the node in which the expression is embedded. - :param loops_info: an iterator of 2-tuples; each tuple represents a loop - enclosing the expression (first entry) and its parent (second entry). - :param mode: the suggested rewrite mode. - """ - self._type = type - self._parent = parent - self._loops_info = list(loops_info) - self._mode = mode - - @property - def type(self): - return self._type - - @property - def dims(self): - return tuple(l.dim for l in self.loops) - - @property - def linear_dims(self): - return tuple(l.dim for l in self.linear_loops) - - @property - def out_linear_dims(self): - return tuple(d for d in self.dims if d not in self.linear_dims) - - @property - def reduction_dims(self): - return tuple(l.dim for l in self.reduction_loops) - - @property - def loops(self): - return list(zip(*self._loops_info))[0] - - @property - def loops_from_dims(self): - return OrderedDict(zip(self.dims, self.loops)) - - @property - def loops_parents(self): - return list(zip(*self._loops_info))[1] - - @property - def loops_info(self): - return self._loops_info - - @property - def linear_loops(self): - return tuple([l for l in self.loops if l.is_linear]) - - @property - def linear_loops_parents(self): - return tuple([p for l, p in self._loops_info if l.is_linear]) - - @property - def linear_loops_info(self): - return tuple([(l, p) for l, p in self._loops_info if l.is_linear]) - - @property - def out_linear_loops(self): - return tuple([l for l in self.loops if l not in self.linear_loops]) - - @property - def out_linear_loops_parents(self): - return tuple([p for p in self.loops_parents if p not in self.linear_loops_parents]) - - @property - def out_linear_loops_info(self): - return tuple([i for i in self.loops_info if i not in self.linear_loops_info]) - - @property - def reduction_loops(self): - stmts = Find((Writer, Incr)).visit(self.parent) - if stmts[Incr]: - writers = flatten(stmts.values()) - return tuple(l for l in self.loops - if all(l.dim not in i.lvalue.rank for i in writers)) - else: - return () - - @property - def reduction_loops_parents(self): - retval = self.reduction_loops_info - return zip(*retval)[1] if retval else () - - @property - def reduction_loops_info(self): - return tuple((l, p) for l, p in self.loops_info if l in self.reduction_loops) - - @property - def perfect_loops(self): - """Return the loops in a perfect loop nest for the expression.""" - return [l for l in self.loops if is_perfect_loop(l)] - - @property - def parent(self): - return self._parent - - @property - def outermost_loop(self): - return self.loops[0] if len(self.loops) > 0 else None - - @property - def outermost_parent(self): - return self.loops_parents[0] if len(self.loops_parents) > 0 else None - - @property - def outermost_linear_loop(self): - return self.linear_loops[0] if len(self.linear_loops) > 0 else None - - @property - def outermost_linear_loop_parent(self): - return self.linear_loops_parents[0] if len(self.linear_loops_parents) > 0 else None - - @property - def innermost_loop(self): - return self.loops[-1] if len(self.loops) > 0 else None - - @property - def innermost_parent(self): - return self.loops_parents[-1] if len(self.loops_parents) > 0 else None - - @property - def innermost_linear_loop(self): - return self.linear_loops[-1] if len(self.linear_loops) > 0 else None - - @property - def innermost_linear_loop_parent(self): - return self.linear_loops_parents[-1] if len(self.linear_loops_parents) > 0 else None - - @property - def dimension(self): - return len(self.linear_dims) if not self.is_scalar else 0 - - @property - def is_scalar(self): - return all([isinstance(i, int) for i in self.linear_dims]) - - @property - def is_tensor(self): - return not self.is_scalar - - @property - def is_linear(self): - return self.dimension == 1 - - @property - def is_bilinear(self): - return self.dimension == 2 - - @property - def mode(self): - return self._mode - - @mode.setter - def mode(self, value): - self._mode = value - - -def copy_metaexpr(expr_info, **kwargs): - """Given a ``MetaExpr``, return a plain new ``MetaExpr`` starting from a - copy of ``expr_info``, and replaces some attributes as specified in - ``kwargs``. ``kwargs`` accepts the following keys: parent, loops_info, - mode.""" - - parent = kwargs.get('parent', expr_info.parent) - mode = kwargs.get('mode', expr_info.mode) - - new_loops_info, old_loops_info = [], expr_info.loops_info - to_replace_loops_info = kwargs.get('loops_info', []) - to_replace_loops_info = dict(zip([l.dim for l, p in to_replace_loops_info], - to_replace_loops_info)) - for loop_info in old_loops_info: - loop_dim = loop_info[0].dim - if loop_dim in to_replace_loops_info: - new_loops_info.append(to_replace_loops_info[loop_dim]) - else: - new_loops_info.append(loop_info) - - return MetaExpr(expr_info.type, parent, new_loops_info, mode) diff --git a/coffee/optimizer.py b/coffee/optimizer.py deleted file mode 100644 index 1a437e00..00000000 --- a/coffee/optimizer.py +++ /dev/null @@ -1,69 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from .utils import StmtTracker - - -class LoopOptimizer(object): - - def __init__(self, loop, header, exprs): - """Initialize the LoopOptimizer. - - :param loop: root AST node of a loop nest - :param header: the kernel's top node - :param exprs: list of expressions to be optimized - """ - self.loop = loop - self.header = header - self.exprs = exprs - - # Track nonzero regions accessed in each symbol - self.nz_syms = {} - # Track hoisted expressions - self.hoisted = StmtTracker() - - @property - def expr_linear_loops(self): - """Return ``[(loop1, loop2, ...), ...]``, where each tuple contains all - linear loops enclosing expressions.""" - return [expr_info.linear_loops for expr_info in self.exprs.values()] - - -class CPULoopOptimizer(LoopOptimizer): - - """Loop optimizer for CPU architectures.""" - - -class GPULoopOptimizer(LoopOptimizer): - - """Loop optimizer for GPU architectures.""" diff --git a/coffee/plan.py b/coffee/plan.py deleted file mode 100644 index 218508cb..00000000 --- a/coffee/plan.py +++ /dev/null @@ -1,233 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -"""COFFEE's optimization plan constructor.""" -from __future__ import absolute_import, print_function, division - -import coffee -from .base import * -from .utils import * -from .optimizer import CPULoopOptimizer, GPULoopOptimizer -from .vectorizer import LoopVectorizer, VectStrategy -from .expression import MetaExpr -from .logger import log, warn, PERF_OK, PERF_WARN -from coffee.visitors import Find, EstimateFlops - -from collections import OrderedDict -import time - - -class ASTKernel(object): - - """Manipulate the kernel's Abstract Syntax Tree.""" - - def __init__(self, ast, include_dirs=None): - self.ast = ast - self.include_dirs = include_dirs or [] - - def plan_cpu(self, opts): - """Optimize this :class:`ASTKernel` for CPU execution. - - :param opts: a dictionary of optimizations to be applied. For a description - of the recognized optimizations, please refer to the ``coffee.set_opt_level`` - documentation. If equal to ``None``, the default optimizations in - ``coffee.options['optimizations']`` are applied; these are either the - optimizations set when COFFEE was initialized or those changed through - a call to ``set_opt_level``. In this way, a default set of optimizations - is applied to all kernels, but users are also allowed to select - specific transformations for individual kernels. - """ - - start_time = time.time() - - kernels = Find(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] - - if opts is None: - opts = coffee.OptimizationLevel.retrieve(coffee.options['optimizations']) - else: - opts = coffee.OptimizationLevel.retrieve(opts.get('optlevel', {})) - - flops_pre = EstimateFlops().visit(self.ast) - - for kernel in kernels: - rewrite = opts.get('rewrite') - vectorize = opts.get('vectorize', (None, None)) - align_pad = opts.get('align_pad') - split = opts.get('split') - dead_ops_elimination = opts.get('dead_ops_elimination') - - info = visit(kernel, info_items=['decls', 'exprs']) - # Collect expressions and related metadata - nests = OrderedDict() - for stmt, expr_info in info['exprs'].items(): - parent, nest = expr_info - if not nest: - continue - if kernel.template: - typ = "double" - else: - typ = check_type(stmt, info['decls']) - metaexpr = MetaExpr(typ, parent, nest) - nests.setdefault(nest[0], OrderedDict()).update({stmt: metaexpr}) - loop_opts = [CPULoopOptimizer(loop, header, exprs) - for (loop, header), exprs in nests.items()] - - # Combining certain optimizations is forbidden. - if dead_ops_elimination and split: - warn("Split forbidden with dead-ops elimination") - return - if dead_ops_elimination and vectorize[0]: - warn("Vect forbidden with dead-ops elimination") - return - if rewrite == 'auto' and len(info['exprs']) > 1: - warn("Rewrite auto forbidden with multiple exprs") - rewrite = 4 - - # Main Ootimization pipeline - for loop_opt in loop_opts: - - # 0) Expression Rewriting - if rewrite: - loop_opt.rewrite(rewrite) - - # 1) Dead-operations elimination - if dead_ops_elimination: - loop_opt.eliminate_zeros() - - # 2) Code specialization - if split: - loop_opt.split(split) - if coffee.initialized and flatten(loop_opt.expr_linear_loops): - vect = LoopVectorizer(loop_opt, kernel) - if align_pad: - # Padding and data alignment - vect.autovectorize() - if vectorize[0] and vectorize[0] != VectStrategy.AUTO: - # Specialize vectorization for the memory access pattern - # of the expression - vect.specialize(*vectorize) - - # Ensure kernel is always marked static inline - # Remove either or both of static and inline (so that we get the order right) - kernel.pred = [q for q in kernel.pred if q not in ['static', 'inline']] - kernel.pred.insert(0, 'inline') - kernel.pred.insert(0, 'static') - - # Post processing of the AST ensures higher-quality code - postprocess(kernel) - - flops_post = EstimateFlops().visit(self.ast) - - tot_time = time.time() - start_time - - output = "COFFEE finished in %g seconds (flops: %d -> %d)" % \ - (tot_time, flops_pre, flops_post) - log(output, PERF_OK if flops_post <= flops_pre else PERF_WARN) - - def plan_gpu(self): - """Transform the kernel suitably for GPU execution. - - Loops decorated with a ``pragma coffee itspace`` are hoisted out of - the kernel. The list of arguments in the function signature is - enriched by adding iteration variables of hoisted loops. The size of any - kernel's non-constant tensor is modified accordingly. - - For example, consider the following function: :: - - void foo (int A[3]) { - int B[3] = {...}; - #pragma coffee itspace - for (int i = 0; i < 3; i++) - A[i] = B[i]; - } - - plan_gpu modifies its AST such that the resulting output code is :: - - void foo(int A[1], int i) { - A[0] = B[i]; - } - """ - - # The optimization passes are performed individually (i.e., "locally") for - # each function (or "kernel") found in the provided AST - kernels = Find(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] - - for kernel in kernels: - info = visit(kernel, info_items=['decls', 'exprs']) - # Collect expressions and related metadata - nests = OrderedDict() - for stmt, expr_info in info['exprs'].items(): - parent, nest = expr_info - if not nest: - continue - if kernel.template: - typ = "double" - else: - typ = check_type(stmt, info['decls']) - metaexpr = MetaExpr(typ, parent, nest) - nests.setdefault(nest[0], OrderedDict()).update({stmt: metaexpr}) - loop_opts = [GPULoopOptimizer(loop, header, exprs) - for (loop, header), exprs in nests.items()] - - for loop_opt in loop_opts: - itspace_vrs, accessed_vrs = loop_opt.extract() - - for v in accessed_vrs: - # Change declaration of non-constant iteration space-dependent - # parameters by shrinking the size of the iteration space - # dimension to 1 - decl = set( - [d for d in kernel.args if d.sym.symbol == v.symbol]) - dsym = decl.pop().sym if len(decl) > 0 else None - if dsym and dsym.rank: - dsym.rank = tuple([1 if i in itspace_vrs else j - for i, j in zip(v.rank, dsym.rank)]) - - # Remove indices of all iteration space-dependent and - # kernel-dependent variables that are accessed in an itspace - v.rank = tuple([0 if i in itspace_vrs and dsym else i - for i in v.rank]) - - # Add iteration space arguments - kernel.args.extend([Decl("int", Symbol("%s" % i)) for i in itspace_vrs]) - - # Clean up the kernel removing variable qualifiers like 'static' - for decl in decls.values(): - d, place = decl - d.qual = [q for q in d.qual if q not in ['static', 'const']] - - kernel.pred = [q for q in kernel.pred if q not in ['static', 'inline']] - - def gencode(self): - """Generate a string representation of the AST.""" - return self.ast.gencode() diff --git a/coffee/system.py b/coffee/system.py deleted file mode 100644 index ac4500d0..00000000 --- a/coffee/system.py +++ /dev/null @@ -1,145 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2016, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Provide mechanisms to initialize COFFEE or change its state.""" - -from __future__ import absolute_import, print_function, division - -from coffee.base import * - - -__all__ = ['architecture', 'compiler', 'isa'] - - -def set_architecture(architecture_id): - """Set architecture-specific parameters. Supported architectures: - - * 'default'/'intel': a conventional multi-core CPU, such as an Intel Haswell - """ - - # All sizes are in Bytes - # The /cache_size/ is the size of the private memory closest to a core - - if architecture_id in ['default', 'intel']: - return { - 'cache_size': 256 * 10**3, - 'double': 8 - } - - return {} - - -def set_compiler(compiler_id): - """Set compiler-specific keywords. Supported compilers: - - * 'gnu' (aka gcc) - * 'intel' (aka icc) - """ - - if compiler_id == 'gnu': - return { - 'name': 'gnu', - 'cmd': 'gcc', - 'align': lambda o: '__attribute__((aligned(%s)))' % o, - 'align_forloop': '', - 'force_simdization': '', - 'AVX': '-mavx', - 'SSE': '-msse', - 'ipo': '', - 'native_opt': '-mtune=native', - 'vect_header': '#include ' - } - - if compiler_id == 'intel': - return { - 'name': 'intel', - 'cmd': 'icc', - 'align': lambda o: '__attribute__((aligned(%s)))' % o, - 'align_forloop': '#pragma vector aligned', - 'force_simdization': '#pragma simd', - 'AVX': '-xAVX', - 'SSE': '-xSSE', - 'ipo': '-ip', - 'native_opt': '-xHost', - 'vect_header': '#include ' - } - - return {} - - -def set_isa(isa_id): - """Set the instruction set architecture (ISA). Supported ISAs: - - * 'sse' - * 'avx' - """ - - if isa_id == 'sse': - return { - 'inst_set': 'SSE', - 'avail_reg': 16, - 'alignment': 16, - 'dp_reg': 2, # Number of values in double precision per register - 'reg': lambda n: 'xmm%s' % n - } - - if isa_id == 'avx': - return { - 'inst_set': 'AVX', - 'avail_reg': 16, - 'alignment': 32, - 'dp_reg': 4, # Number of values in double precision per register - 'reg': lambda n: 'ymm%s' % n, - 'zeroall': '_mm256_zeroall ()', - 'setzero': AVXSetZero(), - 'decl_var': '__m256d', - 'align_array': lambda p: '__attribute__((aligned(%s)))' % p, - 'symbol_load': lambda s, r, o=None: AVXLoad(s, r, o), - 'symbol_set': lambda s, r, o=None: AVXSet(s, r, o), - 'store': lambda m, r: AVXStore(m, r), - 'mul': lambda r1, r2: AVXProd(r1, r2), - 'div': lambda r1, r2: AVXDiv(r1, r2), - 'add': lambda r1, r2: AVXSum(r1, r2), - 'sub': lambda r1, r2: AVXSub(r1, r2), - 'l_perm': lambda r, f: AVXLocalPermute(r, f), - 'g_perm': lambda r1, r2, f: AVXGlobalPermute(r1, r2, f), - 'unpck_hi': lambda r1, r2: AVXUnpackHi(r1, r2), - 'unpck_lo': lambda r1, r2: AVXUnpackLo(r1, r2) - } - - return {} - - -architecture = {} -compiler = {} -isa = {} # Instruction Set Architecture diff --git a/coffee/utils.py b/coffee/utils.py deleted file mode 100644 index 88985ba4..00000000 --- a/coffee/utils.py +++ /dev/null @@ -1,923 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Utility functions for the inspection, transformation, and creation of ASTs.""" - -from __future__ import absolute_import, print_function, division -from six import iterkeys, iteritems -from six.moves import zip - -from copy import deepcopy as dcopy -from collections import defaultdict, OrderedDict, namedtuple - -from coffee.base import * -from coffee.visitors.inspectors import * -from coffee.visitors.utilities import Reconstructor - - -##################################### -# Functions to manipulate AST nodes # -##################################### - - -def ast_replace(node, to_replace, copy=True, mode='all'): - """ - Given the ``to_replace`` dictionary ``{k: v}``, replace each - ``k`` rooted in ``node`` with ``v``. - - :param copy: pass False to avoid reconstructing ``v`` each time ``k`` - is encountered. - :param mode: either ``all``, in which case ``to_replace``'s keys are turned - into strings, and all of the occurrences are removed from the - AST; or ``symbol``, in which case only (all of) the references - to the symbols given in ``to_replace`` are replaced. - """ - assert mode in ['all', 'symbol'] - - if mode == 'all': - to_replace = {str(k): v for k, v in to_replace.items()} - should_replace = lambda n: to_replace.get(str(n)) - else: - should_replace = lambda n: to_replace.get(n) - - replacements = [] - - def _ast_replace(node, to_replace): - replaced_children = {} - for i, n in enumerate(node.children): - v = should_replace(n) - if v: - replaced_children[i] = ast_reconstruct(v) if copy else v - replacements.append(replaced_children[i]) - else: - _ast_replace(n, to_replace) - for i, r in replaced_children.items(): - node.children[i] = r - - _ast_replace(node, to_replace) - - return replacements - - -def ast_reconstruct(node): - """Recursively reconstruct ``node``.""" - return Reconstructor().visit(node) - - -def ast_update_ofs(node, ofs, **kwargs): - """Change the offsets of the iteration space variables of the symbols rooted - in ``node``. - - :arg node: root AST node - :arg ofs: a dictionary ``{'dim': value}``; `dim`'s offset is changed to `value` - :arg kwargs: optional parameters to drive the transformation: - * increase: `value` is added to the pre-existing offset, not substituted - """ - increase = kwargs.get('increase', False) - - symbols = Find(Symbol).visit(node)[Symbol] - for s in symbols: - new_offset = [] - for r, o in zip(s.rank, s.offset): - if increase: - val = ofs.get(r, 0) - if isinstance(o[1], str) or isinstance(val, str): - new_o = "%s + %s" % (o[1], val) - else: - new_o = o[1] + val - else: - new_o = ofs.get(r, o[1]) - new_offset.append((o[0], new_o)) - s.offset = tuple(new_offset) - - return node - - -def ast_update_rank(node, mapper): - """Change the rank of the symbols rooted in ``node`` as prescribed by - ``rank``. - - :arg node: Root AST node - :arg mapper: Describe how to change the rank of a symbol. For example, - if mapper={'i': 'j'} and node='A[i] = B[i]', then node will be - transformed into 'A[j] = B[j]' - """ - - for s in Find(Symbol).visit(node)[Symbol]: - s.rank = tuple([r if r not in mapper else mapper[r] for r in s.rank]) - - -############################################### -# Functions to simplify creation of AST nodes # -############################################### - - -def ast_make_for(stmts, loop, copy=False): - """Create a for loop having the same iteration space as ``loop`` enclosing - the statements in ``stmts``. If ``copy == True``, then new instances of - ``stmts`` are created""" - wrap = Block(dcopy(stmts) if copy else stmts, open_scope=True) - new_loop = For(dcopy(loop.init), dcopy(loop.cond), dcopy(loop.incr), - wrap, dcopy(loop.pragma)) - return new_loop - - -def ast_make_expr(op, nodes, balance=True): - """Create an ``Expr`` Node of type ``op``, with children given in ``nodes``.""" - - def _ast_make_expr(nodes): - return nodes[0] if len(nodes) == 1 else op(nodes[0], _ast_make_expr(nodes[1:])) - - def _ast_make_bal_expr(nodes): - half = len(nodes) // 2 - return nodes[0] if len(nodes) == 1 else op(_ast_make_bal_expr(nodes[:half]), - _ast_make_bal_expr(nodes[half:])) - - if len(nodes) == 0: - return None - elif balance: - return _ast_make_bal_expr(nodes) - else: - return _ast_make_expr(nodes) - - -def ast_make_alias(node, alias_name): - """ - Create an alias of ``node`` (must be of type Decl). The alias symbol is - given the name ``alias_name``. For example: :: - - (node, alias_name) --> output - (double * a, b) --> double * b = a - (double a[1], b) --> double * b = a - (double a[1][1], b) --> double (*b)[1] = a - """ - assert isinstance(node, Decl) - - pointers = list(node.pointers) - if len(node.size) == 1: - pointers += [''] - if len(node.size) > 1: - symbol = SymbolIndirection(alias_name, node.size[1:]) - else: - symbol = Symbol(alias_name, node.size[1:]) - - return Decl(node.typ, symbol, node.lvalue.symbol, qualifiers=node.qual, - scope=node.scope, pointers=pointers) - - -########################################################### -# Functions to visit and to query properties of AST nodes # -########################################################### - - -def visit(node, parent=None, info_items=None): - """Explore the AST rooted in ``node`` and collect various info, including: - - * Loop nests encountered - a list of tuples, each tuple representing a loop nest - * Declarations - a dictionary {variable name (str): declaration (AST node)} - * Symbols (dependencies) - a dictionary {symbol (AST node): [loops] it depends on} - * Symbols (access mode) - a dictionary {symbol (AST node): access mode (WRITE, ...)} - * String to Symbols - a dictionary {symbol (str): [(symbol, parent) (AST nodes)]} - * Expressions - mathematical expressions to optimize (decorated with a pragma) - - :param node: AST root node of the visit - :param parent: parent node of ``node`` - :param info_items: An optional list of information to gather, - valid values are:: - - - "symbols_dep" - - "decls" - - "exprs" - - "fors" - - "symbol_refs" - - "symbols_mode" - """ - info = {} - - if info_items is None: - info_items = ['decls', 'symbols_dep', 'symbol_refs', - 'symbols_mode', 'exprs', 'fors'] - if 'decls' in info_items: - retval = SymbolDeclarations.default_retval() - info['decls'] = SymbolDeclarations().visit(node, ret=retval) - - if 'symbols_dep' in info_items: - deps = SymbolDependencies().visit(node, ret=SymbolDependencies.default_retval(), - **SymbolDependencies.default_args) - # Prune access mode: - for k in list(iterkeys(deps)): - if type(k) is not Symbol: - del deps[k] - info['symbols_dep'] = deps - - if 'exprs' in info_items: - retval = FindCoffeeExpressions.default_retval() - info['exprs'] = FindCoffeeExpressions().visit(node, parent=parent, ret=retval) - - if 'fors' in info_items: - retval = FindLoopNests.default_retval() - info['fors'] = FindLoopNests().visit(node, parent=parent, ret=retval) - - if 'symbol_refs' in info_items: - retval = SymbolReferences.default_retval() - info['symbol_refs'] = SymbolReferences().visit(node, parent=parent, ret=retval) - - if 'symbols_mode' in info_items: - retval = SymbolModes.default_retval() - info['symbols_mode'] = SymbolModes().visit(node, parent=parent, ret=retval) - - return info - - -def loops_analysis(node, key='default', value='default'): - """Perform loop dependence analysis in the AST rooted in ``node``. Return - a dictionary mapping symbols to loops they depend on. - - :arg key: any value in ['default', 'urepr', 'symbol']. With 'urepr' and - 'symbol' different instances of the same Symbol are represented by - a single entry in the returned dictionary. - :arg value: any value in ['default', 'dim']. If 'dim' is specified, then - loop iteration dimensions are used in place of the actual object. - """ - symbols_dep = visit(node, info_items=['symbols_dep'])['symbols_dep'] - - if key == 'default': - gen_key = lambda s: s - elif key == 'urepr': - gen_key = lambda s: s.urepr - elif key == 'symbol': - gen_key = lambda s: s.symbol - else: - raise RuntimeError("Illegal key=%s for loop dependence analysis" % key) - - if value == 'default': - lda = defaultdict(list) - update = lambda i, dep: i.extend(list(dep)) - elif value == 'dim': - lda = defaultdict(set) - update = lambda i, dep: i.update({j.dim for j in dep}) - else: - raise RuntimeError("Illegal value=%s for loop dependence analysis" % value) - - for s, dep in symbols_dep.items(): - update(lda[gen_key(s)], dep) - - return lda - - -def reachability_analysis(node): - """ - Perform reachability analysis in the AST rooted in ``node``. Return - a dictionary mapping symbols to scopes in which they are visible. - """ - return SymbolVisibility().visit(node)[0] - - -def explore_operator(node): - """Return a list of the operands composing the operation whose root is - ``node``.""" - - def _explore_operator(node, operator, children): - for n in node.children: - if n.__class__ == operator: - _explore_operator(n, operator, children) - else: - children.append((n, node)) - - children = [] - _explore_operator(node, node.__class__, children) - return children - - -def inner_loops(node): - """Find inner loops in the subtree rooted in ``node``.""" - - return FindInnerLoops().visit(node) - - -def is_perfect_loop(loop): - """Return True if ``loop`` is part of a perfect loop nest, False otherwise.""" - - return CheckPerfectLoop().visit(loop) - - -def in_written(node, key='default'): - """Return a list of symbols written in ``node``. - - :arg key: any value in ['default', 'urepr', 'symbol']. With 'urepr' and - 'symbol' different instances of the same Symbol are represented by - a single entry in the returned dictionary. - """ - - if key == 'default': - gen_key = lambda s: s - elif key == 'urepr': - gen_key = lambda s: s.urepr - elif key == 'symbol': - gen_key = lambda s: s.symbol - else: - raise RuntimeError("Illegal key=%s for in_written" % key) - - found = [] - writers = Find(Writer).visit(node) - for type, stmts in writers.items(): - for stmt in stmts: - found.append(gen_key(stmt.lvalue)) - - return found - - -def in_read(node, key='default'): - """ - Return a list of symbols read in ``node``. - - :arg key: any value in ['default', 'urepr', 'symbol']. With 'urepr' and - 'symbol' different instances of the same Symbol are represented by - a single entry in the returned dictionary. - """ - - if key == 'default': - gen_key = lambda s: s - elif key == 'urepr': - gen_key = lambda s: s.urepr - elif key == 'symbol': - gen_key = lambda s: s.symbol - else: - raise RuntimeError("Illegal key=%s for in_read" % key) - - found = [] - writers = Find(Writer).visit(node) - for type, stmts in writers.items(): - for stmt in stmts: - reads = Find(Symbol).visit(stmt.rvalue)[Symbol] - found.extend([gen_key(s) for s in reads]) - - return found - - -def count(node, mode='urepr', read_only=False): - """Count the occurrences of all variables appearing in ``node``. For example, - for the expression: :: - - ``a*(5+c) + b*(a+4)`` - - return :: - - ``{a: 2, b: 1, c: 1}`` - - :param node: The root of the AST visited - :param mode: Set the key in the returned dictionary. Accepted values - are ['urepr', 'symbol_id'], where: - * mode == 'urepr': (default) use the symbol representation as key - * mode == 'symbol_id': use the symbol name as key, thus ignoring - any iteration space or offset. For example, if both A[0] and A[i] - appear in ``node``, return {A: 2, ...} (assuming no other - occurrences of A) - :param read_only: True if only variables on the right hand side of a statement - should be counted; False if any appearance should be counted. - """ - modes = ['urepr', 'symbol_id'] - if mode == 'urepr': - key = lambda n: n.urepr - elif mode == 'symbol_id': - key = lambda n: n.symbol - else: - raise RuntimeError("`Count` function got a wrong mode (valid: %s)" % modes) - - v = CountOccurences(key=key, only_rvalues=read_only) - return v.visit(node, ret=v.default_retval()) - - -def check_type(stmt, decls): - """Check the types of the ``stmt``'s LHS and RHS. If they match as expected, - return the type itself. Otherwise, an error is generated, suggesting an issue - in either the AST itself (i.e., a bug inherent the AST) or, possibly, in the - optimization process. - - :param stmt: the AST node statement to be checked - :param decls: a dictionary from symbol identifiers (i.e., strings representing - the name of a symbol) to Decl nodes - """ - v = SymbolReferences() - lhs_symbol, = iterkeys(v.visit(stmt.lvalue, parent=stmt, ret=v.default_retval())) - rhs_symbols = iterkeys(v.visit(stmt.rvalue, parent=stmt, ret=v.default_retval())) - - lhs_decl = decls[lhs_symbol] - rhs_decls = [decls[s] for s in rhs_symbols if s in decls] - - type = lambda d: d.typ.replace('*', '') - if any([type(lhs_decl) != type(rhs_decl) for rhs_decl in rhs_decls]): - raise RuntimeError("Non matching types in %s" % str(stmt)) - - return type(lhs_decl) - - -def find_expression(node, type=None, dims=None, in_syms=None, out_syms=None): - """Wrapper of the FindExpression visitor.""" - finder = FindExpression(type, dims, in_syms, out_syms) - exprs = finder.visit(node, ret=FindExpression.default_retval()) - if 'cleaned' in exprs: - exprs.pop('cleaned') - if 'in_syms' in exprs: - exprs.pop('in_syms') - if 'out_syms' in exprs: - exprs.pop('out_syms') - if 'inner_syms' in exprs: - exprs.pop('inner_syms') - if 'in_itspace' in exprs: - exprs.pop('in_itspace') - return exprs - - -def summands(node): - """ - Return the top-level summands in /node/. - - Examples - ======== - - a + b --> [a, b] - a*b*c --> [a*b*c] - a*b*c + c*d --> [a*b*c, c*d] - (a+b)*c + d --> [(a+b)*c, d] - foo(a) --> [] - """ - - handle = list(zip(*explore_operator(node))) - if not handle: - return [] - operands, parents = handle - if all(isinstance(p, Sum) for p in parents): - return operands - elif all(isinstance(p, Prod) for p in parents): - # Single top-level summand - return [node] - else: - return [] - - -####################################################################### -# Functions to manipulate iteration spaces in various representations # -####################################################################### - - -class ItSpace(object): - - """A collection of routines to manipulate iteration spaces.""" - - def __init__(self, mode=0): - """Initialize an ItSpace object. - - :arg mode: Establish how an interation space is represented. - :type mode: integer, allowed [0 (default), 1, 2]; respectively, an - iteration space is represented as: - * 0: a 2-tuple indicating the bounds of the accessed region - * 1: a 2-tuple indicating size and offset of the accessed region - * 2: a For loop object - """ - assert mode in [0, 1, 2], "Invalid mode for ItSpace()" - self.mode = mode - - def _convert_to_mode0(self, itspaces): - if self.mode == 0: - return tuple(itspaces) - elif self.mode == 1: - return tuple((ofs, ofs+size) for size, ofs in itspaces) - elif self.mode == 2: - return tuple((l.start, l.end) for l in itspaces) - - def _convert_from_mode0(self, itspaces): - if self.mode == 0: - return itspaces - elif self.mode == 1: - return [Region(end-start, start) for start, end in itspaces] - elif self.mode == 2: - raise RuntimeError("Cannot convert from mode=0 to mode=2") - - def merge(self, itspaces, within=None): - """Merge contiguous, possibly overlapping iteration spaces. - For example (assuming ``self.mode = 0``): :: - - [(1,3), (4,6)] -> ((1,6),) - [(1,3), (5,6)] -> ((1,3), (5,6)) - - :arg within: an integer representing the distance between two iteration - spaces to be considered adjacent. Defaults to 1. - """ - itspaces = self._convert_to_mode0(itspaces) - within = within or 1 - - itspaces = sorted(tuple(set(itspaces))) - merged_itspaces = [] - current_start, current_stop = itspaces[0] - for start, stop in itspaces: - if start - within > current_stop: - merged_itspaces.append((current_start, current_stop)) - current_start, current_stop = start, stop - else: - # Ranges adjacent or overlapping: merge. - current_stop = max(current_stop, stop) - merged_itspaces.append((current_start, current_stop)) - - itspaces = self._convert_from_mode0(merged_itspaces) - return itspaces - - def intersect(self, itspaces): - """Compute the intersection of multiple iteration spaces. - For example (assuming ``self.mode = 0``): :: - - [(1,3)] -> () - [(1,3), (4,6)] -> () - [(1,3), (2,6)] -> (2,3) - """ - itspaces = self._convert_to_mode0(itspaces) - - if len(itspaces) == 0: - return () - elif len(itspaces) > 1: - itspaces = [set(range(i[0], i[1])) for i in itspaces] - itspace = set.intersection(*itspaces) - itspace = sorted(list(itspace)) or [0, -1] - itspaces = [(itspace[0], itspace[-1]+1)] - - itspace = self._convert_from_mode0(itspaces)[0] - return itspace - - def to_for(self, itspaces, dims=None, stmts=None): - """Create ``For`` objects starting from an iteration space.""" - if not dims and self.mode == 2: - dims = [l.dim for l in itspaces] - elif not dims: - dims = ['i%d' % i for i, j in enumerate(itspaces)] - - itspaces = self._convert_to_mode0(itspaces) - - loops = [] - body = Block(stmts or [], open_scope=True) - for (start, stop), dim in reversed(list(zip(itspaces, dims))): - new_for = For(Decl("int", dim, start), Less(dim, stop), Incr(dim, 1), body) - loops.insert(0, new_for) - body = Block([new_for], open_scope=True) - - return loops - - -############################################################### -# Utilities for tracking the global impact of transformations # -############################################################### - - -class StmtTracker(OrderedDict): - - """Track the location of generic statements in an abstract syntax tree. - - Each key in the dictionary is a string representing a symbol. As such, - StmtTracker can be used only in SSA scopes. Each entry in the dictionary - is a 4-tuple containing information about the symbol: :: - - (statement, declaration, closest_for, place) - - whose semantics is, respectively, as follows: - - * The AST node whose ``str(lvalue)`` is used as dictionary key - * The AST node of the symbol declaration - * The AST node of the closest loop enclosing the statement - * The parent of the closest loop - """ - - class StmtInfo(object): - """Simple container class defining ``StmtTracker`` values.""" - - INFO = ['stmt', 'decl', 'loop', 'place'] - - def __init__(self, **kwargs): - for k, v in iteritems(kwargs): - assert(k in self.__class__.INFO) - setattr(self, k, v) - - def __init__(self): - super(StmtTracker, self).__init__() - self.byvalue = OrderedDict() - - def __setitem__(self, key, value): - if not isinstance(value, self.StmtInfo): - if not isinstance(value, tuple): - raise RuntimeError("StmtTracker accepts tuple or StmtInfo objects") - assert len(self.StmtInfo.INFO) == len(value) - value = self.StmtInfo(**dict(zip(self.StmtInfo.INFO, value))) - self.byvalue[value.stmt.rvalue.urepr] = key - return OrderedDict.__setitem__(self, key, value) - - def update_stmt(self, sym, **kwargs): - """Given the symbol ``sym``, it updates information related to it as - specified in ``kwargs``. If ``sym`` is not present, return ``None``. - ``kwargs`` is based on the following special keys: - - * "stmt": change the statement - * "decl": change the declaration - * "loop": change the closest loop - * "place": change the parent the closest loop - """ - if sym not in self: - return None - for k, v in iteritems(kwargs): - assert(k in self.StmtInfo.INFO) - setattr(self[sym], k, v) - - def update_loop(self, loop_a, loop_b): - """Replace all occurrences of ``loop_a`` with ``loop_b`` in all entries.""" - - for sym, sym_info in self.items(): - if sym_info.loop == loop_a: - self.update_stmt(sym, **{'loop': loop_b}) - - def get_symbol(self, value): - """Return the key associated to the provided ``value``, or None if not - present.""" - return self.byvalue.get(value.urepr) - - @property - def stmt(self, sym): - return self[sym].stmt if self.get(sym) else None - - @property - def decl(self, sym): - return self[sym].decl if self.get(sym) else None - - @property - def loop(self, sym): - return self[sym].loop if self.get(sym) else None - - @property - def place(self, sym): - return self[sym].place if self.get(sym) else None - - @property - def all_stmts(self): - return set((stmt_info.stmt for stmt_info in self.values() if stmt_info.stmt)) - - @property - def all_places(self): - return set((stmt_info.place for stmt_info in self.values() if stmt_info.place)) - - @property - def all_loops(self): - return set((stmt_info.loop for stmt_info in self.values() if stmt_info.loop)) - - -class ExpressionGraph(object): - - """Track read-after-write dependencies between symbols.""" - - def __init__(self, node): - """Initialize the ExpressionGraph. - - :param node: root of the AST visited to initialize the ExpressionGraph. - """ - import networkx as nx - self.deps = nx.DiGraph() - writes = Find(Writer).visit(node) - for type, nodes in writes.items(): - for n in nodes: - if isinstance(n.rvalue, EmptyStatement): - continue - self.add_dependency(n.lvalue, n.rvalue) - - def add_dependency(self, sym, expr): - """Add dependency between ``sym`` and symbols appearing in ``expr``.""" - expr_symbols = Find(Symbol).visit(expr)[Symbol] - for es in expr_symbols: - self.deps.add_edge(sym.symbol, es.symbol) - - def has_dependency(self): - """Return True if a read-after-write (raw) or write-after-read (war) - dependency appears in the graph, False otherwise.""" - if self.deps.edges(): - sources, targets = zip(*self.deps.edges()) - return True if set(sources) & set(targets) else False - else: - return False - - def is_read(self, expr, target_sym=None): - """Return True if any symbols in ``expr`` is read by ``target_sym``, - False otherwise. If ``target_sym`` is None, Return True if any symbols - in ``expr`` are read by at least one symbol, False otherwise.""" - import networkx as nx - input_syms = Find(Symbol).visit(expr)[Symbol] - for s in input_syms: - if s.symbol not in self.deps: - continue - elif not target_sym: - if list(zip(*self.deps.in_edges(s.symbol))): - return True - elif nx.has_path(self.deps, target_sym.symbol, s.symbol): - return True - return False - - def is_written(self, expr, target_sym=None): - """Return True if any symbols in ``expr`` is written by ``target_sym``, - False otherwise. If ``target_sym`` is None, Return True if any symbols - in ``expr`` are written by at least one symbol, False otherwise.""" - import networkx as nx - input_syms = Find(Symbol).visit(expr)[Symbol] - for s in input_syms: - if s.symbol not in self.deps: - continue - elif not target_sym: - if list(zip(*self.deps.out_edges(s.symbol))): - return True - elif nx.has_path(self.deps, s.symbol, target_sym.symbol): - return True - return False - - def shares(self, symbols): - """Return an iterator of tuples, each tuple being a group of symbols - identifiers sharing the same reads.""" - groups = set() - for i in [set(self.reads(s)) for s in symbols]: - group = tuple(j for j in symbols if i.intersection(set(self.reads(j)))) - groups.add(group) - return list(groups) - - def readers(self, sym): - """Return the list of symbol identifiers that read from ``sym``.""" - return [i for i, j in self.deps.in_edges(sym)] - - def reads(self, sym): - """Return the list of symbol identifiers that ``sym`` reads from.""" - return [j for i, j in self.deps.out_edges(sym)] - - -######################## -# Simple support types # -######################## - - -Region = namedtuple('Region', ['size', 'ofs']) - - -############################# -# Generic utility functions # -############################# - - -any_in = lambda a, b: any(i in b for i in a) -flatten = lambda list: [i for l in list for i in l] -bind = lambda a, b: [(a, v) for v in b] -od_find_next = lambda a, b: a.values()[a.keys().index(b)+1] - - -def as_urepr(l): - convert = lambda i: i.urepr if isinstance(i, Symbol) else i - try: - converted = [convert(i) for i in l] - except TypeError: - converted = convert(l) - return tuple(converted) - - -def is_const_dim(d): - return isinstance(d, int) or (isinstance(d, str) and d.isdigit()) - - -def insert_at_elem(_list, elem, new_elem, ofs=0): - ofs = _list.index(elem) + ofs - new_elem = [new_elem] if not isinstance(new_elem, list) else new_elem - for e in reversed(new_elem): - _list.insert(ofs, e) - - -def uniquify(exprs): - """Iterate over ``exprs`` and return a list of expressions in which duplicates - have been discarded. This function considers two expressions identical if they - have the same string representation.""" - return OrderedDict([(e.urepr, e) for e in exprs]).values() - - -def remove_empty_loops(node): - """Remove all empty loops within node.""" - - for nest in visit(node, info_items=['fors'])['fors']: - to_remove = (None, None) - for loop, parent in reversed(nest): - if not loop.body or all(i == to_remove[0] for i in loop.body): - to_remove = (loop, parent) - if all(to_remove): - loop, parent = to_remove - parent.children.remove(loop) - - -def remove_unused_decls(node): - """Remove all unused decls within node, which must be of type :class:`Block`.""" - - assert isinstance(node, Block) - - decls = Find(Decl, with_parent=True).visit(node)[Decl] - references = visit(node, info_items=['symbol_refs'])['symbol_refs'] - for d, p in decls: - if len(references[d.sym.symbol]) == 1: - p.children.remove(d) - - -def cleanup(node): - """Remove useless nodes in the AST rooted in node.""" - - remove_empty_loops(node) - remove_unused_decls(node) - - -def postprocess(node): - """Rearrange the Nodes in the AST rooted in ``node`` to improve the code quality - when unparsing the tree.""" - - class Process(object): - start = None - end = None - decls = {} - blockable = [] - _processed = [] - - @staticmethod - def mark(node): - if Process.start is not None: - Process._processed.append((node, Process.start, Process.end, - Process.decls, Process.blockable)) - Process.start = None - Process.end = None - Process.decls = {} - Process.blockable = [] - - def init_decl(node): - lhs, rhs = node.children - decl = Process.decls.get(lhs.symbol) - if decl and (not decl.init or isinstance(decl.init, EmptyStatement)): - decl.init = rhs - Process.blockable.remove(decl) - return decl - else: - return node - - def update(node, parent): - index = parent.children.index(node) - if Process.start is None: - Process.start = index - Process.end = index - - def make_blocks(): - for node, start, end, _, blockable in reversed(Process._processed): - node.children[start:end+1] = [Block(blockable, open_scope=False)] - - def _postprocess(node, parent): - if isinstance(node, FlatBlock) and str(node).isspace(): - update(node, parent) - elif isinstance(node, (For, If, Switch, FunCall, FunDecl, FlatBlock, LinAlg, - Block, Root)): - Process.mark(parent) - for n in node.children: - _postprocess(n, node) - Process.mark(node) - elif isinstance(node, Decl): - if not (node.init and not isinstance(node.init, EmptyStatement)) and \ - not node.sym.rank: - Process.decls[node.sym.symbol] = node - update(node, parent) - Process.blockable.append(node) - elif isinstance(node, AugmentedAssign): - update(node, parent) - Process.blockable.append(node) - elif isinstance(node, Assign): - update(node, parent) - Process.blockable.append(init_decl(node)) - - _postprocess(node, None) - make_blocks() diff --git a/coffee/vectorizer.py b/coffee/vectorizer.py deleted file mode 100644 index 73ac9031..00000000 --- a/coffee/vectorizer.py +++ /dev/null @@ -1,740 +0,0 @@ -# This file is part of COFFEE -# -# COFFEE is Copyright (c) 2014, Imperial College London. -# Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import absolute_import, print_function, division -from six.moves import range - -from math import ceil -from copy import deepcopy as dcopy -from collections import OrderedDict, namedtuple -from itertools import product - -from .base import * -from .utils import * -from . import system -from .logger import warn -from coffee.visitors import Find - - -class VectStrategy(object): - - """Supported vectorization modes.""" - - """Generate scalar code suitable for compiler auto-vectorization""" - AUTO = 1 - - """Specialized (intrinsics-based) vectorization using padding""" - SPEC_PADD = 2 - - """Specialized (intrinsics-based) vectorization using a peeling loop""" - SPEC_PEEL = 3 - - """Specialized (intrinsics-based) vectorization composed with unroll-and-jam - of outer loops, padding (to enforce data alignment), and peeling of padded - iterations""" - SPEC_UAJ_PADD = 4 - - """Specialized (intrinsics-based) vectorization composed with unroll-and-jam - of outer loops and padding (to enforce data alignment)""" - SPEC_UAJ_PADD_FULL = 5 - - -class LoopVectorizer(object): - - def __init__(self, loop_opt, kernel=None): - self.kernel = kernel or loop_opt.header - self.header = loop_opt.header - self.loop = loop_opt.loop - self.exprs = loop_opt.exprs - self.nz_syms = loop_opt.nz_syms - - def autovectorize(self, p_dim=-1): - """Generate code suitable for compiler auto-vectorization. - - Three code transformations may be applied here: - - * Padding - * Data alignment - - OR, if the outermost loop has an interation space much larger than that - of the inner loops, - - * Data layout transposition - - Padding consists of three major steps: - - * Pad the innermost dimension of all n-dimensional arrays to the nearest - multiple of the vector length. - * Round up, to the nearest multiple of the vector length, the bounds of all - innermost loops in which padded arrays are accessed. - * Since padding may induce data alignment of multi-dimensional arrays - (in practice, this depends on the presence of offsets as well), add - suitable '#pragma' to innermost loops to tell the backend compiler - if this property holds. - - Padding works as follows. Assume a vector length of size 4, and consider - the following piece of code: :: - - void foo(int A[10][10]): - int B[10] = ... - for i = 0 to 10: - for j = 0 to 10: - A[i][j] = B[i][j] - - Once padding is applied, the code will look like: :: - - void foo(int A[10][10]): - int _A[10][12] = {{0.0}}; - int B[10][12] = ... - for i = 0 to 10: - for j = 0 to 12: - _A[i][j] = B[i][j] - - for i = 0 to 10: - for j = 0 to 10: - A[i][j] = _A[i][j] - - Extra care is taken if offsets (e.g. A[i+3][j+3] ...) are used. In such - a case, the buffer array '_A' in the example above can be vector-expanded: :: - - int _A[x][10][12]; - ... - - Where 'x' corresponds to the number of different offsets used in a given - iteration space along the innermost dimension. - - Finally, all arrays are decorated with suitable attributes to enforce - alignment to (the size in bytes of) the vector length. - - :arg p_dim: the array dimension that should be padded (defaults to the - innermost, or -1) - """ - info = visit(self.header, info_items=['decls', 'fors', 'symbols_dep', - 'symbols_mode', 'symbol_refs']) - - padded = self._pad(p_dim, info['decls'], info['fors'], info['symbols_dep'], - info['symbols_mode'], info['symbol_refs']) - if padded: - self._align_data(p_dim, info['decls']) - - def _pad(self, p_dim, decls, fors, symbols_dep, symbols_mode, symbol_refs): - """Apply padding.""" - to_invert = Find(Invert).visit(self.header)[Invert] - - # Loop increments different than 1 are unsupported - if any([l.increment != 1 for l, _ in flatten(fors)]): - return None - - DSpace = namedtuple('DSpace', ['region', 'nest', 'symbols']) - ISpace = namedtuple('ISpace', ['region', 'nest', 'bag']) - - buf_decl = None - for decl_name, decl in decls.items(): - if not decl.size or decl.is_pointer_type: - continue - - p_rank = decl.size[:p_dim] + (vect_roundup(decl.size[p_dim]),) - if decl.size[p_dim] == 1 or p_rank == decl.size: - continue - - if decl.scope == LOCAL: - decl.pad(p_rank) - continue - - # At this point we are sure /decl/ is a FunDecl argument - - # A) Can a buffer actually be allocated ? - symbols = [s for s, _ in symbol_refs[decl_name] if s is not decl.sym] - if not all(s.dim == decl.sym.dim and s.is_const_offset for s in symbols): - continue - periods = flatten([s.periods for s in symbols]) - if not all(p == 1 for p in periods): - continue - - # ... must be either READ or WRITE mode - modes = [symbols_mode[s][0] for s in symbols] - if not modes or any(m != modes[0] for m in modes): - continue - mode = modes[0] - if mode not in [READ, WRITE]: - continue - - # ... accesses to entries in /decl/ must be explicit in all loop nests - deps = OrderedDict((s, [l for l in symbols_dep[s] if l.dim in s.rank]) - for s in symbols) - if not all(s.dim == len(n) for s, n in deps.items()): - continue - - # ... organize symbols based on their dataspace - dspace_mapper = OrderedDict() - for s, n in deps.items(): - n.sort(key=lambda l: s.rank.index(l.dim)) - region = tuple(Region(l.size, l.start + i) for i, l in zip(s.strides, n)) - dspace = DSpace(region=region, nest=n, symbols=[]) - dspace_mapper.setdefault(dspace.region, dspace) - dspace.symbols.append(s) - - # ... is there any overlap in the memory accesses? Memory accesses must: - # - either completely overlap (they will be mapped to the same buffer) - # - OR be disjoint - will_break = False - for regions1, regions2 in product(dspace_mapper.keys(), dspace_mapper.keys()): - for r1, r2 in zip(regions1, regions2): - if ItSpace(mode=1).intersect([r1, r2]) not in [(0, 0), r1]: - will_break = True - if will_break: - continue - - # ... initialize buffer-related data - buf_name = '_' + decl_name - buf_nz = self.nz_syms.setdefault(buf_name, []) - - # ... determine the non zero-valued region in the buffer - for n, region in enumerate(dspace_mapper.keys()): - p_region = (Region(region[p_dim].size, 0),) - buf_nz.append((Region(1, n),) + region[:p_dim] + p_region) - - # ... replace symbols in the AST with proper buffer instances - itspace_mapper = OrderedDict() - for n, dspace in enumerate(dspace_mapper.values()): - itspace = ISpace(region=tuple((l.size, l.start) for l in dspace.nest), - nest=dspace.nest, bag=OrderedDict()) - itspace = itspace_mapper.setdefault(itspace.region, itspace) - for s in dspace.symbols: - original = Symbol(s.symbol, s.rank, s.offset) - s.symbol = buf_name - s.rank = (n,) + s.rank - s.offset = ((1, 0),) + s.offset[:p_dim] + ((1, 0),) - if s.urepr not in [i.urepr for i in itspace.bag.values()]: - itspace.bag[original] = Symbol(s.symbol, s.rank, s.offset) - - # ... insert the buffer into the AST - buf_dim = n + 1 - buf_rank = (buf_dim,) + decl.size - init = ArrayInit(np.ndarray(shape=(1,)*len(buf_rank), buffer=np.array(0.0))) - buf_decl = Decl(decl.typ, Symbol(buf_name, buf_rank), init, scope=BUFFER) - buf_decl.pad((buf_dim,) + p_rank) - self.header.children.insert(0, buf_decl) - - # C) Create a loop nest for copying data into/from the buffer - for itspace in itspace_mapper.values(): - - if mode == READ: - stmts = [Assign(b, s) for s, b in itspace.bag.items()] - copy_back = ItSpace(mode=2).to_for(itspace.nest, stmts=stmts) - insert_at_elem(self.header.children, buf_decl, copy_back[0], ofs=1) - - elif mode == WRITE: - # If extra information (a pragma) is present, telling that - # the argument does not need to be incremented because it does - # not contain any meaningful values, then we can safely write - # to it. This optimization may avoid useless increments - can_write = WRITE in decl.pragma and len(itspace_mapper) == 1 - op = Assign if can_write else Incr - stmts = [op(s, b) for s, b in itspace.bag.items()] - copy_back = ItSpace(mode=2).to_for(itspace.nest, stmts=stmts) - if to_invert: - insert_at_elem(self.header.children, to_invert[0], copy_back[0]) - else: - self.header.children.append(copy_back[0]) - - # D) Update the global data structures - decls[buf_name] = buf_decl - - return buf_decl - - def _align_data(self, p_dim, decls): - """Apply data alignment. This boils down to: - - * Decorate declarations with qualifiers for data alignment - * Round up the bounds (i.e. /start/ and /end/ points) of loops such - that all memory accesses get aligned to the vector length. Several - checks ensure the correctness of the transformation. - """ - vector_length = system.isa["dp_reg"] - align = system.compiler['align'](system.isa['alignment']) - - # Array alignment - for decl in decls.values(): - if decl.sym.rank and decl.scope == LOCAL: - decl.attr.append(align) - - # Loop bounds adjustment - for l in inner_loops(self.header): - should_round = True - - for stmt in l.body: - sym, expr = stmt.lvalue, stmt.rvalue - decl = decls[sym.symbol] - - # Condition A: the lvalue can be a scalar only if /stmt/ is not an - # augmented assignment, otherwise the extra iterations would alter - # its value - if not sym.rank and isinstance(stmt, AugmentedAssign): - should_round = False - break - - # Condition B: the fastest varying dimension of the lvalue must be /l/ - if sym.rank and not sym.rank[p_dim] == l.dim: - should_round = False - break - - # Condition C: the lvalue must have been padded - if sym.rank and decl.size[p_dim] != vect_roundup(decl.size[p_dim]): - should_round = False - break - - symbols = [sym] + Find(Symbol).visit(expr)[Symbol] - symbols = [s for s in symbols if s.rank and any(r == l.dim for r in s.rank)] - - # Condition D: the access pattern must be accessible - if any(not s.is_unit_period for s in symbols): - # Cannot infer the access pattern so must break - should_round = False - break - - # Condition E: extra iterations induced by bounds and offset rounding - # must not alter the computation - for s in symbols: - decl = decls[s.symbol] - index = s.rank.index(l.dim) - stride = s.strides[index] - extra = list(range(stride + l.size, stride + vect_roundup(l.size))) - # Do any of the extra iterations alter the computation ? - if any(i > decl.size[index] for i in extra): - # ... outside of the legal region, abort - should_round = False - break - if all(i >= decl.core[index] for i in extra): - # ... in the padded region, pass - continue - nz = list(self.nz_syms.get(s.symbol)) - if not nz: - # ... lacks the non zero-valued entries mapping, abort - should_round = False - break - # ... get the non zero-valued entries in the right dimension - nz_index = [] - for i in nz: - can_skip = False - for j, r in enumerate(s.rank[:index]): - if not is_const_dim(r): - continue - if not (i[j].ofs <= r < i[j].ofs + i[j].size): - # ... actually on a different outer dimension, safe - # to avoid this check - can_skip = True - if not can_skip: - nz_index.append(i[index]) - if any(ofs <= i < ofs + size for size, ofs in nz_index): - # ... writing to a non-zero region, abort - should_round = False - break - - if should_round: - l.end = vect_roundup(l.end) - if all(i % vector_length == 0 for i in [l.start, l.size]): - l.pragma.add(system.compiler["align_forloop"]) - l.pragma.add(system.compiler['force_simdization']) - - def _transpose_layout(self, decls): - dim = self.loop.dim - symbols = Find(Symbol).visit(self.loop)[Symbol] - symbols = [s for s in symbols if any(r == dim for r in s.rank) and s.dim > 1] - - # Cannot handle arrays with more than 2 dimensions - if any(s.dim > 2 for s in symbols): - return - - mapper = OrderedDict() - for s in symbols: - mapper.setdefault(decls[s.symbol], list()).append(s) - - for decl, syms in mapper.items(): - # Adjust the declaration - transposed_values = decl.init.values.transpose() - decl.init.values = transposed_values - decl.sym.rank = transposed_values.shape - - # Adjust the instances - for s in syms: - s.rank = tuple(reversed(s.rank)) - - def specialize(self, opts, factor=1): - """Generate code for specialized expression vectorization. Check for peculiar - memory access patterns in an expression and replace scalar code with highly - optimized vector code. Currently, the following patterns are supported: - - * Outer products - e.g. A[i]*B[j] - - Also, code generation is supported for the following instruction sets: - - * AVX - - The parameter ``opts`` can be used to drive the transformation process by - specifying one of the vectorization strategies in :class:`VectStrategy`. - """ - vs = VectStrategy - if opts not in [vs.SPEC_UAJ_PADD, vs.SPEC_UAJ_PADD_FULL, - vs.SPEC_PADD, vs.SPEC_PEEL]: - warn("Don't know how to specialize vectorization for %s" % opts) - if system.isa['inst_set'] == 'SSE': - warn("Don't know how to specialize vectorization for SSE") - - layout = None - for stmt, expr_info in self.exprs.items(): - if expr_info.dimension != 2: - continue - parent = expr_info.parent - linear_loops = expr_info.linear_loops - linear_loops_parents = expr_info.linear_loops_parents - - # Check if outer-product vectorization is actually doable - vect_len = system.isa["dp_reg"] - rows = linear_loops[0].size - if rows < vect_len: - continue - - op = OuterProduct(stmt, linear_loops, 'STORE') - - # Vectorisation - unroll_factor = factor if opts in [vs.SPEC_UAJ_PADD, vs.SPEC_UAJ_PADD_FULL] else 1 - rows_per_it = vect_len*unroll_factor - if opts == vs.SPEC_UAJ_PADD: - if rows_per_it <= rows: - body, layout = op.generate(rows_per_it) - else: - # Unroll factor too big - body, layout = op.generate(vect_len) - elif opts == SPEC_UAJ_PADD_FULL: - if rows <= rows_per_it or vect_roundup(rows) % rows_per_it > 0: - # Cannot unroll too much - body, layout = op.generate(vect_len) - else: - body, layout = op.generate(rows_per_it) - elif opts in [vs.SPEC_PADD, vs.SPEC_PEEL]: - body, layout = op.generate(vect_len) - - # Construct the remainder loop - if opts != vs.SPEC_UAJ_PADD_FULL and rows > rows_per_it and rows % rows_per_it > 0: - # Adjust bounds and increments of the main, layout and remainder loops - linear_outerloop = linear_loops[0] - peel_loop = dcopy(linear_loops) - bound = linear_outerloop.end - bound -= bound % rows_per_it - linear_outerloop.end, layout.end = bound, bound - peel_loop[0].init.init = Symbol(bound) - peel_loop[0].increment, peel_loop[1].increment = 1, 1 - # Append peeling loop after the main loop - linear_outerparent = linear_loops_parents[0].children - insert_at_elem(linear_outerparent, linear_outerloop, peel_loop[0], 1) - - # Replace scalar with vector code - ofs = parent.children.index(stmt) - parent.children[ofs:ofs] = body - parent.children.remove(stmt) - - # Insert the layout code right after the loop nest enclosing the expression - if layout: - insert_at_elem(self.header.children, expr_info.loops[0], layout, 1) - - -class OuterProduct(object): - - """Generate an intrinsics-based outer product vectorisation of a statement.""" - - def __init__(self, stmt, loops, mode): - self.stmt = stmt - self.loops = loops - self.mode = mode - - class Alloc(object): - - """Handle allocation of register variables. """ - - def __init__(self, tensor_size): - nres = max(system.isa["dp_reg"], tensor_size) - self.ntot = system.isa["avail_reg"] - self.res = [system.isa["reg"](v) for v in range(nres)] - self.var = [system.isa["reg"](v) for v in range(nres, self.ntot)] - self.i = system.isa - - def get_reg(self): - if len(self.var) == 0: - l = self.ntot * 2 # noqa: E741 - self.var += [self.i["reg"](v) for v in range(self.ntot, l)] - self.ntot = l - return self.var.pop(0) - - def free_regs(self, regs): - for r in reversed(regs): - self.var.insert(0, r) - - def get_tensor(self): - return self.res - - def _swap_reg(self, step, vrs): - """Swap values in a vector register. """ - - # Find inner variables - regs = [reg for node, reg in vrs.items() - if node.rank and node.rank[-1] == self.loops[1].dim] - - if step in [0, 2]: - return [Assign(r, system.isa["l_perm"](r, "5")) for r in regs] - elif step == 1: - return [Assign(r, system.isa["g_perm"](r, r, "1")) for r in regs] - elif step == 3: - return [] - - def _vect_mem(self, vrs, decls): - """Return a list of vector variable declarations representing - loads, sets, broadcasts. - - :arg vrs: dictionary that associates scalar variables to vector. - variables, for which it will be generated a corresponding - intrinsics load/set/broadcast. - :arg decls: list of scalar variables for which an intrinsics load/ - set/broadcast has already been generated, possibly updated - by this method. - """ - stmt = [] - for node, reg in vrs.items(): - if node.rank and node.rank[-1] in [l.dim for l in self.loops]: - exp = system.isa["symbol_load"](node.symbol, node.rank, node.offset) - else: - exp = system.isa["symbol_set"](node.symbol, node.rank, node.offset) - if not decls.get(node.gencode()): - decls[node.gencode()] = reg - stmt.append(Decl(system.isa["decl_var"], reg, exp)) - return stmt - - def _vect_expr(self, node, ofs, regs, decls, vrs): - """Turn a scalar expression into its intrinsics equivalent. - - :arg node: AST expression to be vectorized. - :arg ofs: contains the offset of the entry in the left hand side that - is being vectorized. - :arg regs: register allocator. - :arg decls: list of scalar variables for which an intrinsics load/ - set/broadcast has already been generated. - :arg vrs: dictionary that associates scalar variables to vector variables. - Updated every time a new scalar variable is encountered. - """ - if isinstance(node, Symbol): - if node.rank and self.loops[0].dim == node.rank[-1]: - # The symbol depends on the outer loop dimension, so add offset - n_ofs = tuple([(1, 0) for i in range(len(node.rank)-1)]) + ((1, ofs),) - node = Symbol(node.symbol, dcopy(node.rank), n_ofs) - node_ide = node.gencode() - if node_ide not in decls: - reg = [k for k in vrs.keys() if k.gencode() == node_ide] - if not reg: - vrs[node] = Symbol(regs.get_reg()) - return vrs[node] - else: - return vrs[reg[0]] - else: - return decls[node_ide] - else: - left = self._vect_expr(node.left, ofs, regs, decls, vrs) - right = self._vect_expr(node.right, ofs, regs, decls, vrs) - if isinstance(node, Sum): - return system.isa["add"](left, right) - elif isinstance(node, Sub): - return system.isa["sub"](left, right) - elif isinstance(node, Prod): - return system.isa["mul"](left, right) - elif isinstance(node, Div): - return system.isa["div"](left, right) - - def _incr_tensor(self, tensor, ofs, regs, out_reg): - """Add the right hand side contained in out_reg to tensor. - - :arg tensor: the left hand side of the expression being vectorized. - :arg ofs: contains the offset of the entry in the left hand side that - is being computed. - :arg regs: register allocator. - :arg out_reg: register variable containing the left hand side. - """ - if self.mode == 'STORE': - # Store in memory - sym = tensor.symbol - rank = tensor.rank - ofs = tensor.offset[:-2] + ((1, ofs),) + tensor.offset[-1:] - load = system.isa["symbol_load"](sym, rank, ofs) - return system.isa["store"](Symbol(sym, rank, ofs), - system.isa["add"](load, out_reg)) - elif self.mode == 'MOVE': - # Accumulate on a vector register - reg = Symbol(regs.get_tensor()[ofs], ()) - return Assign(reg, system.isa["add"](reg, out_reg)) - - def _restore_layout(self, regs, tensor): - """Restore the storage layout of the tensor. - - :arg regs: register allocator. - :arg tensor: the left hand side of the expression being vectorized. - """ - code = [] - t_regs = [Symbol(r, ()) for r in regs.get_tensor()] - n_regs = len(t_regs) - - # Create tensor symbols - tensor_syms = [] - for i in range(n_regs): - ofs = tensor.offset[:-2] + ((1, i),) + tensor.offset[-1:] - tensor_syms.append(Symbol(tensor.symbol, tensor.rank, ofs)) - - # Load LHS values from memory - if self.mode == 'STORE': - for i, j in zip(tensor_syms, t_regs): - load_sym = system.isa["symbol_load"](i.symbol, i.rank, i.offset) - code.append(Decl(system.isa["decl_var"], j, load_sym)) - - # In-register restoration of the tensor layout - perm = system.isa["g_perm"] - uphi = system.isa["unpck_hi"] - uplo = system.isa["unpck_lo"] - typ = system.isa["decl_var"] - vect_len = system.isa["dp_reg"] - # Do as many times as the unroll factor - spins = int(ceil(n_regs / vect_len)) - for i in range(spins): - # In-register permutations - tmp = [Symbol(regs.get_reg(), ()) for r in range(vect_len)] - code.append(Decl(typ, tmp[0], uphi(t_regs[1], t_regs[0]))) - code.append(Decl(typ, tmp[1], uplo(t_regs[0], t_regs[1]))) - code.append(Decl(typ, tmp[2], uphi(t_regs[2], t_regs[3]))) - code.append(Decl(typ, tmp[3], uplo(t_regs[3], t_regs[2]))) - code.append(Assign(t_regs[0], perm(tmp[1], tmp[3], 32))) - code.append(Assign(t_regs[1], perm(tmp[0], tmp[2], 32))) - code.append(Assign(t_regs[2], perm(tmp[3], tmp[1], 49))) - code.append(Assign(t_regs[3], perm(tmp[2], tmp[0], 49))) - regs.free_regs([s.symbol for s in tmp]) - - # Store LHS values in memory - for j in range(min(vect_len, n_regs - i * vect_len)): - ofs = i * vect_len + j - code.append(system.isa["store"](tensor_syms[ofs], t_regs[ofs])) - - return code - - def generate(self, rows): - """Generate the outer-product intrinsics-based vectorisation code. - - By default, the tensor computed by the outer product vectorization is - kept in memory, so the layout is restored by means of explicit load and - store instructions. The resulting code will therefore look like: :: - - for ... - for j - for k - for ... - A[j][k] = ...intrinsics-based outer product along ``j-k``... - for j - for k - A[j][k] = ...intrinsics-based code for layout restoration... - - The other possibility would be to keep the computed values in temporaries - after a suitable permutation of the loops in the nest; this variant can be - activated by passing ``mode='MOVE'``, but it is not recommended unless - loops are very small *and* a suitable permutation of the nest has been - chosen to minimize register spilling. - """ - cols = system.isa["dp_reg"] - tensor, expr = self.stmt.children - tensor_size = cols - - # Get source-level variables - regs = self.Alloc(tensor_size) - - # Adjust loops' increment - self.loops[0].incr.children[1] = Symbol(rows) - self.loops[1].incr.children[1] = Symbol(cols) - - stmts, decls, vrs = [], {}, {} - rows_per_col = rows // cols - rows_to_peel = rows % cols - peeling = 0 - for i in range(cols): - # Handle extra rows - if peeling < rows_to_peel: - nrows = rows_per_col + 1 - peeling += 1 - else: - nrows = rows_per_col - for j in range(nrows): - # Vectorize, declare allocated variables, increment tensor - ofs = j * cols - v_expr = self._vect_expr(expr, ofs, regs, decls, vrs) - stmts.extend(self._vect_mem(vrs, decls)) - incr = self._incr_tensor(tensor, i + ofs, regs, v_expr) - stmts.append(incr) - # Register shuffles - if rows_per_col + (rows_to_peel - peeling) > 0: - stmts.extend(self._swap_reg(i, vrs)) - - # Set initialising and tensor layout code - layout = self._restore_layout(regs, tensor) - if self.mode == 'STORE': - # Tensor layout - layout_loops = dcopy(self.loops) - layout_loops[0].incr.children[1] = Symbol(cols) - layout_loops[0].children = [Block([layout_loops[1]], open_scope=True)] - layout_loops[1].children = [Block(layout, open_scope=True)] - layout = layout_loops[0] - elif self.mode == 'MOVE': - # Initialiser - for r in regs.get_tensor(): - decl = Decl(system.isa["decl_var"], Symbol(r, ()), system.isa["setzero"]) - self.loops[1].body.insert(0, decl) - # Tensor layout - self.loops[1].body.extend(layout) - layout = None - - return (stmts, layout) - - -# Utility functions - -def vect_roundup(x): - """Return x rounded up to the vector length. """ - word_len = system.isa.get("dp_reg") or 1 - return int(ceil(x / word_len)) * word_len - - -def vect_rounddown(x): - """Return x rounded down to the vector length. """ - word_len = system.isa.get("dp_reg") or 1 - return x - (x % word_len) diff --git a/coffee/visitors/__init__.py b/coffee/visitors/__init__.py index 0ab4dc24..ab2ddd18 100644 --- a/coffee/visitors/__init__.py +++ b/coffee/visitors/__init__.py @@ -1,3 +1 @@ -from __future__ import absolute_import, print_function, division from coffee.visitors.utilities import * # noqa -from coffee.visitors.inspectors import * # noqa diff --git a/coffee/visitors/inspectors.py b/coffee/visitors/inspectors.py deleted file mode 100644 index 7a5ba4b4..00000000 --- a/coffee/visitors/inspectors.py +++ /dev/null @@ -1,685 +0,0 @@ -from __future__ import absolute_import, print_function, division -from coffee.visitor import Visitor -from coffee.base import READ, WRITE, LOCAL, EXTERNAL, Symbol, EmptyStatement, Writer -from collections import defaultdict, OrderedDict, Counter -import itertools - -__all__ = ["FindInnerLoops", "CheckPerfectLoop", "CountOccurences", - "FindLoopNests", "FindCoffeeExpressions", "SymbolReferences", - "SymbolDependencies", "SymbolModes", "SymbolDeclarations", - "SymbolVisibility", "Find", "FindExpression"] - - -class FindInnerLoops(Visitor): - - """Find all inner-most loops in an AST. - - Returns a list of the inner-most :class:`.For` loops or an empty - list if none were found.""" - - def visit_object(self, o): - return [] - - def visit_Node(self, o): - # Concatenate transformed children - ops, _ = o.operands() - args = [self.visit(op) for op in ops] - return list(itertools.chain(*args)) - - def visit_For(self, o): - # Check for loops in children - children = self.visit(o.children[0]) - if children: - # Yes, return those - return children - # No return ourselves - return [o] - - -class CheckPerfectLoop(Visitor): - - """ - Check if a Node is a perfect loop nest. - """ - - def visit_object(self, o, *args, **kwargs): - # Unhandled, return False to be safe. - return False - - def visit_Node(self, o, in_loop=False, *args, **kwargs): - # Assume all nodes are in a perfect loop if they're in a loop. - return in_loop - - def visit_For(self, o, in_loop=False, multi=False, *args, **kwargs): - if in_loop and multi: - return False - return self.visit(o.children[0], in_loop=True, multi=multi) - - def visit_Block(self, o, in_loop=False, multi=False, *args, **kwargs): - # Does this block contain multiple statements? - multi = multi or len(o.children) > 1 - return in_loop and all(self.visit(op, in_loop=in_loop, multi=multi) for op in o.children) - - -class CountOccurences(Visitor): - - @classmethod - def default_retval(cls): - return Counter() - - r"""Count all occurances of :class:`~.Symbol`\s in an AST. - - :arg key: a comparison key for the symbols. - :arg only_rvalues: optionally only count rvalues in statements. - - Returns a dict mapping symbol keys to number of occurrences. - """ - def __init__(self, key=lambda x: (x.symbol, x.rank), only_rvalues=False): - self.key = key - self.rvalues = only_rvalues - super(CountOccurences, self).__init__() - - def visit_object(self, o, ret=None, *args, **kwargs): - # Not a symbol, return identity for summation - return ret - - def visit_list(self, o, ret=None, *args, **kwargs): - # Walk list entries (since some operands methods return lists) - for entry in o: - ret = self.visit(entry, ret=ret, *args, **kwargs) - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - def visit_Writer(self, o, ret=None, *args, **kwargs): - if self.rvalues: - # Only counting rvalues, so don't walk lvalue - ops = o.children[1:] - else: - ops = o.children - for op in ops: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - def visit_Decl(self, o, ret=None, *args, **kwargs): - if self.rvalues: - # Only counting rvalues, so don't walk lvalue - ret = self.visit(o.init, ret=ret, *args, **kwargs) - else: - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - def visit_Symbol(self, o, ret=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - ret[self.key(o)] += 1 - return ret - - -class FindLoopNests(Visitor): - - @classmethod - def default_retval(cls): - return list() - - """Return a list of lists of loop nests in the tree. - - Each list entry describes a loop nest with the outer-most loop - first. Each entry therein is a tuple (loop_node, parent). - - By default the top-level call to visit will record a parent - of None for the visited Node. To provide one, pass a keyword - argument in to the visitor:: - - .. code-block:: - - v.visit(node, parent=parent) - - """ - - def visit_object(self, o, ret=None, *args, **kwargs): - return ret - - def visit_Node(self, o, ret=None, parent=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - # Visit children recording this node as the parent - ret = self.visit(op, ret=ret, parent=o) - return ret - - def visit_For(self, o, ret=None, parent=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - ops, _ = o.operands() - nval = len(ret) - for op in ops: - ret = self.visit(op, ret=ret, parent=o) - # Cons (node, node_parent) onto front of current loop-nest list - me = (o, parent) - if len(ret) == nval: - # Bottom of the nest, add myself to ret - ret.append([me]) - return ret - # Transform new children (inside this loop) - # [a, b] into [(me, a), (me, b)] - for a in ret[nval:]: - a.insert(0, me) - return ret - - -class FindCoffeeExpressions(Visitor): - - @classmethod - def default_retval(cls): - return OrderedDict() - - """ - Search the tree for :class:`~.Writer` statements annotated with - :data:`"#pragma coffee expression"`. Return a dict mapping the - annotated node to a tuple of (node_parent, containing_loop_nest, - index_access). - - By default the top-level call to visit will record a node_parent - of None for the visited Node. To provide one, pass a keyword - argument in to the visitor:: - - .. code-block:: - - v.visit(node, parent=parent) - - """ - - def visit_object(self, o, ret=None, *args, **kwargs): - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, parent=o) - return ret - - def visit_Writer(self, o, ret=None, parent=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - for p in o.pragma: - opts = p.split(" ", 2) - # Don't care if we don't have three values - if len(opts) < 3: - continue - if opts[1] == "coffee" and opts[2] == "expression": - # (parent, loop-nest) - ret[o] = (parent, None) - return ret - return ret - - def visit_For(self, o, ret=None, parent=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - nval = len(ret) - - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, parent=o) - - # Nothing inside this for loop was annotated (we didn't see a - # Writer node with #pragma coffee expression) - if len(ret) == nval: - return ret - me = (o, parent) - # Add nest structure to new items - keys = list(ret.keys())[nval:] - for k in keys: - p, nest = ret[k] - if nest is None: - # Statement is directly underneath this loop, so the - # loop nest structure is just the current loop - nest = [me] - else: - # Inside a nested set of loops, so prepend current - # loop info to nest structure - nest = [me] + nest - ret[k] = p, nest - return ret - - -class SymbolReferences(Visitor): - - @classmethod - def default_retval(cls): - return defaultdict(list) - - """ - Visit the tree and return a dict mapping symbol names to tuples of - (node, node_parent) that reference the symbol with that name. - The node is the Symbol node with said name, the node_parent is the - parent of that node. - - By default the top-level call to visit will record a node_parent - of None for the visited Node. To provide one, pass a keyword - argument in to the visitor:: - - .. code-block:: - - v.visit(node, parent=parent) - - """ - - def visit_Symbol(self, o, ret=None, parent=None): - if ret is None: - ret = self.default_retval() - - # Map name to (node, parent) tuple - ret[o.symbol].append((o, parent)) - return ret - - def visit_ArrayInit(self, o, ret=None, *args, **kwargs): - for entry in o.values: - ret = self.visit(entry, ret=ret, *args, **kwargs) - return ret - - def visit_object(self, o, ret=None, *args, **kwargs): - # Identity - return ret - - def visit_list(self, o, ret=None, *args, **kwargs): - for entry in o: - ret = self.visit(entry, ret=ret, *args, **kwargs) - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, parent=o) - return ret - - -class SymbolDependencies(Visitor): - - @classmethod - def default_retval(cls): - return OrderedDict() - - """ - Visit the tree and return a dict collecting symbol dependencies. - - The returned dict contains maps from nodes to a (possibly - empty) loop list the symbol depends on. - """ - - default_args = dict(loop_nest=[], write=False) - - def visit_Symbol(self, o, ret=None, *args, **kwargs): - write = kwargs["write"] - nest = kwargs["loop_nest"] - if ret is None: - ret = self.default_retval() - if write: - # Remember that this symbol /name/ was written, - # as well as the full current loop nest for the - # symbol itself - ret[o] = [l for l in nest] - ret[o.symbol] = True - else: - # Not being written, only care if the loop indices - # of the current nest access the symbol - ret[o] = [l for l in nest if l.dim in o.rank] - return ret - - def visit_object(self, o, ret=None, *args, **kwargs): - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - visit_EmptyStatement = visit_object - - def visit_Decl(self, o, ret=None, *args, **kwargs): - write = kwargs.pop("write") - ret = self.visit(o.sym, ret=ret, write=True, *args, **kwargs) - # Declaration init could have symbol access - ret = self.visit(o.init, ret=ret, write=write, *args, **kwargs) - return ret - - visit_FunCall = visit_Node - - def visit_Invert(self, o, ret=None, *args, **kwargs): - return self.visit(o.children[0], ret=ret, *args, **kwargs) - - def visit_Writer(self, o, ret=None, *args, **kwargs): - write = kwargs.pop("write") - ret = self.visit(o.children[0], ret=ret, write=True, *args, **kwargs) - for op in o.children[1:]: - ret = self.visit(op, ret=ret, write=write, *args, **kwargs) - return ret - - def visit_For(self, o, ret=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - loop_nest = kwargs.pop("loop_nest") + [o] - nval = len(ret) - # Don't care about symbol access in increments, only children - for op in o.children: - ret = self.visit(op, ret=ret, loop_nest=loop_nest, *args, **kwargs) - # Dependencies for variables that are written in one nest - # and read in a subsequent one need to respect this. - new_keys = set(list(ret.keys())[nval:]) - for k in new_keys: - if type(k) is not Symbol: - continue - if k.symbol in new_keys: - v = ret[k] - # Symbol name was written in some nest - # The dependency for this symbol is therefore - # whatever nest came from visiting the children - # plus the current nest at this point in the tree, - # suitably uniquified. - new_v = [l for l in loop_nest] - new_v.extend([l for l in v if l not in new_v]) - ret[k] = new_v - - return ret - - -class SymbolModes(Visitor): - - @classmethod - def default_retval(cls): - return OrderedDict() - - r""" - Visit the tree and return a dict mapping Symbols to tuples of - (access mode, parent class). - - :class:`~.Symbol`\s are accessed as READ-only unless they appear - as lvalues in a :class:`~.Writer` statement. - - By default the top-level call to visit will record a parent class - of NoneType for Symbols without a parent. To pass in a parent by - hand, provide a keyword argument to the visitor:: - - .. code-block:: - - v.visit(symbol, parent=parent) - - """ - - def visit_object(self, o, ret=None, *args, **kwargs): - return ret - - def visit_list(self, o, ret=None, *args, **kwargs): - for entry in o: - ret = self.visit(entry, ret=ret, *args, **kwargs) - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - # WARNING, if the same Symbol object appears multiple - # times, the "last" access wins, rather than WRITE winning. - # This assumes all nodes in the tree are unique instances - ret = self.visit(op, ret=ret, parent=o) - return ret - - def visit_Symbol(self, o, ret=None, mode=READ, parent=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - ret[o] = (mode, parent.__class__) - return ret - - # Don't do anything with declarations. If you want lvalues to get - # a WRITE unless uninitialised, then custom visitor must be - # written. - def visit_Decl(self, o, ret=None, parent=None, *args, **kwargs): - if type(o.rvalue) is EmptyStatement: - mode = READ - else: - ret = self.visit(o.rvalue, ret=ret, parent=o) - mode = WRITE - ret = self.visit(o.lvalue, ret=ret, parent=o, mode=mode) - return ret - - def visit_Writer(self, o, ret=None, *args, **kwargs): - # lvalues have access mode WRITE - ret = self.visit(o.children[0], ret=ret, parent=o, mode=WRITE) - # All others have access mode READ - for op in o.children[1:]: - ret = self.visit(op, ret=ret, parent=o) - return ret - - visit_Invert = visit_Writer - - -class SymbolDeclarations(Visitor): - - @classmethod - def default_retval(cls): - return OrderedDict() - - """Return a dict mapping symbol names to a tuple of the declaring - node. The node is annotated in place with information about - whether it is a LOCAL declaration or EXTERNAL (via a function - argument). - """ - - def __init__(self): - super(SymbolDeclarations, self).__init__() - - def visit_object(self, o, ret=None, *args, **kwargs): - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - def visit_FunDecl(self, o, ret=None, *args, **kwargs): - for op in o.args: - ret = self.visit(op, ret=ret, scope=EXTERNAL) - for op in o.children: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - def visit_Decl(self, o, ret=None, scope=LOCAL, *args, **kwargs): - if ret is None: - ret = self.default_retval() - o.scope = scope - ret[o.sym.symbol] = o - return ret - - -class SymbolVisibility(Visitor): - - @classmethod - def default_retval(cls): - return defaultdict(list), [] - - """ - Visit the tree and return a dict mapping symbols to tuples of - scopes (AST sub-trees) in which they are legally accessible. - """ - - def __init__(self): - super(SymbolVisibility, self).__init__() - - def visit_Decl(self, o, ret=None, in_scope=None, *args, **kwargs): - if in_scope is not None: - in_scope.append(o) - return ret - - def visit_Block(self, o, ret=None, in_scope=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - if in_scope is None: - in_scope = [] - symbols_vis, scopes = ret - scopes.append(o) - this_scope = list(in_scope) - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, in_scope=this_scope, scopes=scopes) - for d in this_scope: - symbols_vis[d].insert(0, o) - return ret - - def visit_object(self, o, ret=None, *args, **kwargs): - # Identity - return ret - - def visit_list(self, o, ret=None, *args, **kwargs): - for entry in o: - ret = self.visit(entry, ret=ret, *args, **kwargs) - return ret - - def visit_Node(self, o, ret=None, *args, **kwargs): - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, *args, **kwargs) - return ret - - -class Find(Visitor): - - @classmethod - def default_retval(cls): - return defaultdict(list) - - """ - Visit the tree and return a dict mapping types to a list of - instances of that type in the tree. - - :arg types: list of types or single type to search for in the tree. - :arg stop_when_found: optional, don't traverse the children of matching types. - :arg with_parent: optional, track also the parent of the matching type. - """ - - def __init__(self, types, stop_when_found=False, with_parent=False): - self.types = types - self.stop_when_found = stop_when_found - self.with_parent = with_parent - super(Find, self).__init__() - - def useless_traversal(self, o): - """ - Return True if the traversal of the sub-tree rooted in o - is useless given that we are searching for nodes of type /t/ - - E.g., Writers cannot be nested. - """ - if isinstance(o, Writer) and self.types == Writer: - return True - return False - - def visit_object(self, o, ret=None, *args, **kwargs): - return ret - - def visit_list(self, o, ret=None, *args, **kwargs): - for entry in o: - ret = self.visit(entry, ret=ret, *args, **kwargs) - return ret - - def visit_Node(self, o, ret=None, parent=None, *args, **kwargs): - if ret is None: - ret = self.default_retval() - if isinstance(o, self.types): - found = (o, parent) if self.with_parent else o - ret[type(o)].append(found) - # Don't traverse children if stop-on-found - if self.stop_when_found: - return ret - if self.useless_traversal(o): - return ret - # Not found, or traversing children anyway - ops, _ = o.operands() - for op in ops: - ret = self.visit(op, ret=ret, parent=o) - return ret - - -class FindExpression(Visitor): - - @classmethod - def default_retval(cls): - return defaultdict(list) - - """ - Visit the expression tree and return a list of (sub-)expressions matching - particular criteria. - - :arg type: (optional) establish the matching expression' root operator(s), - such as Sum, Sub, ... - :arg dims: (optional) a tuple, each entry representing an iteration space - dimension. Expressions' symbols must iterate along one of these iteration - space dimensions. - :arg in_syms: (optional) expressions must include at least one of the symbols - in this argument. - :arg out_syms: (optional) expressions must exclude all of the symbols in - this argument. - """ - - def __init__(self, type=None, dims=None, in_syms=None, out_syms=None): - self.type = type - self.dims = dims - self.in_syms = in_syms - self.out_syms = out_syms or [] - super(FindExpression, self).__init__() - - def visit_object(self, o, *args, **kwargs): - return self.default_retval() - - def visit_Expr(self, o, parent=None, *args, **kwargs): - ret = self.default_retval() - for i in [self.visit(n, parent=o, *args, **kwargs) for n in o.children]: - for k, v in i.items(): - ret[k].extend([j for j in v if j not in ret[k]]) - if not ret['cleaned'] and all(i in ret for i in ['in_syms', 'in_itspace']): - # Create key for expression /o/ - key = set(ret['in_syms']) - key |= {j for j in ret['inner_syms']} - key = tuple(sorted(key)) - if not self.type or isinstance(o, self.type): - if self.type and isinstance(parent, self.type): - # Postpone expression tracking because the parent has same type - # as the node currently being visited - pass - else: - # Pop inner subexpressions, than push the parent one, since - # it represents a larger match - for k, v in ret.items(): - if set(k).issubset(key): - ret.pop(k) - # Does this subexpression /o/ include any of the forbidden symbols ? - if not any(i in key for i in ret['out_syms']): - # Yay, NO, track it - ret[key] = [o] - else: - # Yes. The first time we match an /out_sym/ we still have to go - # through the /else/ above to remove the inner subexpressions, - # but we should do it only once. 'cleaned' prevents popping - # from happening multiple times - ret['cleaned'] = [True] - else: - ret.pop('in_syms') - return ret - - visit_FunCall = visit_Expr - - def visit_Symbol(self, o, *args, **kwargs): - ret = self.default_retval() - if self.in_syms is None or o.symbol in self.in_syms: - ret['in_syms'] = [o.symbol] - ret['inner_syms'] = [o.symbol] - if o.symbol in self.out_syms: - ret['out_syms'] = [o.symbol] - if self.dims is None or any(r in self.dims for r in o.rank): - ret['in_itspace'] = [True] - return ret diff --git a/coffee/visitors/utilities.py b/coffee/visitors/utilities.py index 1b9a1749..13776348 100644 --- a/coffee/visitors/utilities.py +++ b/coffee/visitors/utilities.py @@ -1,322 +1,9 @@ -from __future__ import absolute_import, print_function, division -from six.moves import map, range - -import itertools -import operator -from copy import deepcopy -from collections import OrderedDict, defaultdict import numpy as np from coffee.visitor import Visitor -from coffee.base import Sum, Sub, Prod, Div, ArrayInit, SparseArrayInit - - -__all__ = ["ReplaceSymbols", "CheckUniqueness", "Uniquify", "Evaluate", - "EstimateFlops", "ProjectExpansion", "Reconstructor"] - - -class ReplaceSymbols(Visitor): - - """Replace named symbols in a tree, returning a new tree. - - :arg syms: A dict mapping symbol names to new Symbol objects. - :arg key: a callable to generate a key from a Symbol, defaults to - the string representation. - :arg copy_result: optionally copy the new Symbol whenever it is - used (guaranteeing that it will be unique)""" - def __init__(self, syms, key=lambda x: str(x), - copy_result=False): - self.syms = syms - self.key = key - self.copy_result = copy_result - super(ReplaceSymbols, self).__init__() - - def visit_Symbol(self, o): - try: - ret = self.syms[self.key(o)] - if self.copy_result: - ops, okwargs = ret.operands() - ret = ret.reconstruct(ops, **okwargs) - return ret - except KeyError: - return o - - def visit_object(self, o): - return o - - visit_Node = Visitor.maybe_reconstruct - - -class CheckUniqueness(Visitor): - - """ - Check if all nodes in a tree are unique instances. - """ - def visit_object(self, o, seen=None): - return seen - - # Some lists appear in operands() - def visit_list(self, o, seen=None): - # Walk list entrys - for entry in o: - seen = self.visit(entry, seen=seen) - return seen - - def visit_Node(self, o, seen=None): - if seen is None: - seen = set() - ops, _ = o.operands() - for op in ops: - seen = self.visit(op, seen=seen) - if o in seen: - raise RuntimeError("Tree does not contain unique nodes") - seen.add(o) - return seen - - -class Uniquify(Visitor): - """ - Uniquify all nodes in a tree by recursively calling reconstruct - """ - - visit_Node = Visitor.always_reconstruct - - def visit_object(self, o): - return deepcopy(o) - - def visit_list(self, o): - return [self.visit(e) for e in o] - - -class Evaluate(Visitor): - - @classmethod - def default_retval(cls): - return OrderedDict() - - """ - Symbolically evaluate an expression enclosed in a loop nest, provided that - all of the symbols involved are constants and their value is known. - - Return a dictionary mapping symbol names to (newly created) Decl nodes, each - declaration being initialized with a proper (newly computed and created) - ArrayInit object. - - :arg decls: dictionary mapping symbol names to known Decl nodes. - :arg track_zeros: True if the evaluated arrays are expected to be block-sparse - and the pattern of zeros should be tracked. - """ - - default_args = dict(loop_nest=[]) - - def __init__(self, decls, track_zeros): - self.decls = decls - self.track_zeros = track_zeros - self.mapper = { - Sum: np.add, - Sub: np.subtract, - Prod: np.multiply, - Div: np.divide - } - - import coffee.vectorizer - self.up = coffee.vectorizer.vect_roundup - self.down = coffee.vectorizer.vect_rounddown - from coffee.utils import ItSpace - self.make_itspace = ItSpace - super(Evaluate, self).__init__() - - def visit_object(self, o, *args, **kwargs): - return self.default_retval() - - def visit_list(self, o, *args, **kwargs): - ret = self.default_retval() - for entry in o: - ret.update(self.visit(entry, *args, **kwargs)) - return ret - - def visit_Node(self, o, *args, **kwargs): - ret = self.default_retval() - for n in o.children: - ret.update(self.visit(n, *args, **kwargs)) - return ret - - def visit_For(self, o, *args, **kwargs): - nest = kwargs.pop("loop_nest") - kwargs["loop_nest"] = nest + [o] - return self.visit(o.body, *args, **kwargs) - - def visit_Writer(self, o, *args, **kwargs): - lvalue = o.children[0] - writes = [l for l in kwargs["loop_nest"] if l.dim in lvalue.rank] - - # Evaluate the expression for each point in in the n-dimensional space - # represented by /writes/ - dims = tuple(l.dim for l in writes) - shape = tuple(l.size for l in writes) - values, precision = np.zeros(shape), None - for i in itertools.product(*[range(j) for j in shape]): - point = {d: v for d, v in zip(dims, i)} - expr_values, precision = self.visit(o.children[1], point=point, *args, **kwargs) - # The sum takes into account reductions - values[i] = np.sum(expr_values) - - # If values is not expected to be block-sparse, just return - if not self.track_zeros: - return {lvalue: ArrayInit(values)} - - # Sniff the values to check for the presence of zero-valued blocks: ... - # ... set default nonzero patten - nonzero = [[(i, 0)] for i in shape] - # ... track nonzeros in each dimension - nonzeros_bydim = values.nonzero() - mapper = [] - for nz_dim in nonzeros_bydim: - mapper_dim = defaultdict(set) - for i, nz in enumerate(nz_dim): - point = [] - # ... handle outer dimensions - for j in nonzeros_bydim[:-1]: - if j is not nz_dim: - point.append((j[i],)) - # ... handle the innermost dimension, which is treated "specially" - # to retain data alignment - for j in nonzeros_bydim[-1:]: - if j is not nz_dim: - point.append(tuple(range(self.down(j[i]), self.up(j[i]+1)))) - mapper_dim[nz].add(tuple(point)) - mapper.append(mapper_dim) - for i, dim in enumerate(mapper[:-1]): - # Group indices iff contiguous /and/ same codomain - def grouper(arg): - m, n = arg - return m-n, dim[n] - ranges = [] - for k, g in itertools.groupby(enumerate(sorted(dim.keys())), grouper): - group = list(map(operator.itemgetter(1), g)) - ranges.append((group[-1]-group[0]+1, group[0])) - nonzero[i] = ranges or nonzero[i] - # Group indices in the innermost dimension iff within vector length size - ranges, grouper = [], lambda n: self.down(n) - for k, g in itertools.groupby(sorted(mapper[-1].keys()), grouper): - group = list(g) - ranges.append((group[-1]-group[0]+1, group[0])) - nonzero[-1] = self.make_itspace(mode=1).merge(ranges or nonzero[-1], within=-1) - - return {lvalue: SparseArrayInit(values, precision, tuple(nonzero))} - - def visit_BinExpr(self, o, *args, **kwargs): - ops, _ = o.operands() - transformed = [self.visit(op, *args, **kwargs) for op in ops] - if any([a is None for a in transformed]): - return - values, precisions = zip(*transformed) - # Precisions must match - assert precisions.count(precisions[0]) == len(precisions) - # Return the result of the binary operation plus forward the precision - return self.mapper[o.__class__](*values), precisions[0] - - def visit_Symbol(self, o, *args, **kwargs): - try: - # Any time a symbol is encountered, we expect to know the /point/ of - # the iteration space which is being evaluated. In particular, - # /point/ is pushed (and then popped) on the environment by a Writer - # node. If /point/ is missing, that means the root of the visit does - # not enclose the whole iteration space, which in turn indicates an - # error in the use of the visitor. - point = kwargs["point"] - except KeyError: - raise RuntimeError("Unknown iteration space point.") - try: - decl = self.decls[o.symbol] - except KeyError: - raise RuntimeError("Couldn't find a declaration for symbol %s" % o) - try: - values = decl.init.values - precision = decl.init.precision - shape = values.shape - except AttributeError: - raise RuntimeError("%s not initialized with a numpy array" % decl) - sliced = 0 - for i, (r, s) in enumerate(zip(o.rank, shape)): - dim = i - sliced - # Three possible cases... - if isinstance(r, int): - # ...the index is used to access a specific dimension (e.g. A[5][..]) - values = values.take(r, dim) - sliced += 1 - elif r in point: - # ...a value is being evaluated along dimension /r/ (e.g. A[r] = B[..][r]) - values = values.take(point[r], dim) - sliced += 1 - else: - # .../r/ is a reduction dimension - values = values.take(list(range(s)), dim) - return values, precision - - -class ProjectExpansion(Visitor): - - @classmethod - def default_retval(cls): - return list() - - """ - Project the output of expression expansion. - The caller should provid a collection of symbols C. The expression tree (nodes - that are not of type :class:`~.Expr` are not allowed) is visited and a set of - tuples returned, one tuple for each symbol in C. Each tuple represents the subset - of symbols in C that will appear in at least one term after expansion. - - For example, be C = [a, b], and consider the following input expression: :: - (a*c + d*e)*(b*c + b*f) - After expansion, the expression becomes: :: - - a*c*b*c + a*c*b*f + d*e*b*c + d*e*b*f - - In which there are four product terms. In these terms, there are two in which - both 'a' and 'b' appear, and there are two in which only 'b' appears. So the - visit will return [(a, b), (b,)]. - - :arg symbols: the collection of symbols searched for - """ - - def __init__(self, symbols): - self.symbols = symbols - super(ProjectExpansion, self).__init__() - - def visit_object(self, o, *args, **kwargs): - return self.default_retval() - - def visit_Expr(self, o, parent=None, *args, **kwargs): - projection = self.default_retval() - for n in o.children: - projection.extend(self.visit(n, parent=o, *args, **kwargs)) - ret = [] - for n in projection: - if n not in ret: - ret.append(n) - return ret - - def visit_Prod(self, o, parent=None, *args, **kwargs): - from coffee.utils import flatten - if isinstance(parent, Prod): - projection = self.default_retval() - for n in o.children: - projection.extend(self.visit(n, parent=o, *args, **kwargs)) - return [list(flatten(projection))] - else: - # Only the top level Prod, in a chain of Prods, should do the - # tensor product - projection = [self.visit(n, parent=o, *args, **kwargs) for n in o.children] - product = itertools.product(*projection) - ret = [list(flatten(i)) for i in product] or projection - return ret - - def visit_Symbol(self, o, *args, **kwargs): - return [[o.symbol]] if o.symbol in self.symbols else [[]] +__all__ = ["EstimateFlops"] class EstimateFlops(Visitor): @@ -381,18 +68,3 @@ def visit_Determinant2x2(self, o, *args, **kwargs): def visit_Determinant3x3(self, o, *args, **kwargs): return 14 - - -class Reconstructor(Visitor): - - """ - Recursively reconstruct abstract syntax trees. - """ - - def visit_object(self, o): - return o - - def visit_Node(self, o): - ops, _ = o.operands() - reconstructed_operands = [self.visit(op) for op in ops] - return o.reconstruct(*reconstructed_operands) diff --git a/tests/test_visitors.py b/tests/test_visitors.py deleted file mode 100644 index d1e30400..00000000 --- a/tests/test_visitors.py +++ /dev/null @@ -1,455 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import pytest -from coffee.base import * -from coffee.visitors import * -from collections import Counter -from functools import reduce - - -@pytest.mark.parametrize("key", - [lambda x: x.symbol, - lambda x: x, - lambda x: x.symbol == "a"], - ids=["symbol_name", "symbol_identity", - "symbol_name_is_a"]) -@pytest.mark.parametrize("symbols", - ["a", - "a,a", - "a,a,b", - "b"]) -def test_count_occurences_block(key, symbols): - v = CountOccurences(key=key) - - symbols = [Symbol(a) for a in symbols.split(",")] - tree = Block(symbols) - - expect = Counter() - for sym in symbols: - expect[key(sym)] += 1 - - assert v.visit(tree) == expect - - -@pytest.mark.parametrize("key", - [lambda x: x.symbol, - lambda x: x, - lambda x: x.symbol == "a"], - ids=["symbol_name", "symbol_identity", - "symbol_name_is_a"]) -@pytest.mark.parametrize("only_rvalues", - [False, True], - ids=["all_children", "only_rvalues"]) -@pytest.mark.parametrize("lvalue", - ["a", "b", "c"]) -@pytest.mark.parametrize("rvalue", - ["a,a", - "a,b,c", - "c", - "b", - "d"]) -def test_count_occurences_assign(key, only_rvalues, - lvalue, rvalue): - v = CountOccurences(key=key, only_rvalues=only_rvalues) - - rvalue = [Symbol(a) for a in rvalue.split(",")] - - lvalue = Symbol(lvalue) - - expect = Counter() - - if not only_rvalues: - expect[key(lvalue)] += 1 - - for sym in rvalue: - expect[key(sym)] += 1 - - rvalue = reduce(Prod, rvalue) - - tree = Assign(lvalue, rvalue) - - assert v.visit(tree) == expect - - -@pytest.mark.parametrize("structure", - ([], - [[]], - [None, []], - [None, [[], []]], - [None, [[None, [], [[]]]]])) -def test_find_inner_loops(structure): - v = FindInnerLoops() - - inner_loops = [] - - def build_loop(structure): - ret = [] - for entry in structure: - if entry is None: - continue - else: - loop = Block([build_loop(entry)]) - ret.append(loop) - loop = For(Symbol("a"), Symbol("b"), Symbol("c"), - Block(ret, open_scope=True)) - if ret == []: - inner_loops.append(loop) - return loop - - loop = build_loop(structure) - - expect = set(inner_loops) - - loops = v.visit(loop) - - assert set(loops) == expect - - -def test_check_perfect_loop(): - v = CheckPerfectLoop() - - a = Symbol("a") - b = Symbol("b") - loop = c_for("i", 10, [Assign(a, b)]).children[0] - - env = dict(in_loop=True, multiple_statements=False) - assert v.visit(loop, **env) - - loop2 = c_for("j", 10, [loop]).children[0] - - assert v.visit(loop2, **env) - - loop3 = c_for("k", 10, [loop2, Assign(b, a)]).children[0] - - assert not v.visit(loop3, **env) - - loop4 = c_for("k", 10, [Assign(a, b), Assign(b, a)]).children[0] - - assert v.visit(loop4, **env) - - -@pytest.fixture -def block_aa(): - a = Symbol("a") - return Block([a, a]) - - -@pytest.fixture -def fun_aa_in_args(): - a = Symbol("a") - return FunDecl("void", "foo", [a, a], Block([Assign(Symbol("b"), - Symbol("c"))])) - - -@pytest.fixture -def fun_aa_in_body(block_aa): - return FunDecl("void", "foo", [], block_aa) - - -@pytest.mark.parametrize("tree", - [block_aa(), - fun_aa_in_args(), - fun_aa_in_body(block_aa())], - ids=["block-repeated-aa", - "fundecl-repeated-aa-args", - "fundecl-repeated-aa-body"]) -def test_check_uniqueness(tree): - v = CheckUniqueness() - - with pytest.raises(RuntimeError): - v.visit(tree) - - -@pytest.mark.parametrize("tree", - [block_aa(), - fun_aa_in_args(), - fun_aa_in_body(block_aa())], - ids=["block-repeated-aa", - "fundecl-repeated-aa-args", - "fundecl-repeated-aa-body"]) -def test_uniquify(tree): - v = Uniquify() - check = CheckUniqueness() - - new_tree = v.visit(tree) - - with pytest.raises(RuntimeError): - check.visit(tree) - - assert check.visit(new_tree) - - -def test_symbol_declarations_decl(): - a = Symbol("a") - - tree = Decl("double", a) - - v = SymbolDeclarations() - - ret = v.visit(tree) - - assert set(ret.keys()) == set([a.symbol]) - - -def test_symbol_declarations_block(): - a = Symbol("a") - b = Symbol("b") - - tree = Block([Decl("int", a), - Decl("double", b)]) - - v = SymbolDeclarations() - - ret = v.visit(tree) - - assert set(ret.keys()) == set([a.symbol, b.symbol]) - - -def test_symbol_declarations_fundecl_args(): - a = Symbol("a") - b = Symbol("b") - - body = Block([Assign(b, a)]) - - tree = FunDecl("void", "foo", [Decl("double", a), Decl("double", b)], - body) - - v = SymbolDeclarations() - - ret = v.visit(tree) - assert set(ret.keys()) == set([a.symbol, b.symbol]) - - -def test_symbol_declarations_fundecl_body(): - a = Symbol("a") - b = Symbol("b") - - body = Block([Decl("int", a), - Decl("double", b)]) - - tree = FunDecl("void", "foo", [], - body) - - v = SymbolDeclarations() - - ret = v.visit(tree) - assert set(ret.keys()) == set([a.symbol, b.symbol]) - - -def test_symbol_declarations_fundecl_both(): - a = Symbol("a") - b = Symbol("b") - - body = Block([Decl("int", a), - Assign(a, b)]) - - tree = FunDecl("void", "foo", [Decl("int", b)], - body) - - v = SymbolDeclarations() - - ret = v.visit(tree) - assert set(ret.keys()) == set([a.symbol, b.symbol]) - - -def test_symbol_dependencies_no_nest(): - a = Symbol("a") - - tree = Assign(a, Symbol("1")) - - v = SymbolDependencies() - - ret = v.visit(tree, **SymbolDependencies.default_args) - - assert ret[a] == [] - - -def test_symbol_dependencies_single_loop(): - - a = Symbol("a") - i = Symbol("i") - tree = c_for(i, 2, [Assign(a, i)]) - - v = SymbolDependencies() - - ret = v.visit(tree, **SymbolDependencies.default_args) - - assert ret[a] == [tree.children[0]] - - -def test_symbol_dependencies_read_single_loop(): - a = Symbol("a", rank=("i", )) - b = Symbol("b") - tree = c_for("i", 2, [Assign(b, a)]) - - v = SymbolDependencies() - ret = v.visit(tree, **SymbolDependencies.default_args) - - assert ret[b] == [tree.children[0]] - - assert ret[a] == [tree.children[0]] - - -def test_symbol_dependencies_double_loop(): - a = Symbol("a", rank=("i", )) - b = Symbol("b") - tree = c_for("i", 2, [c_for("j", 1, [IMul(b, a)])]) - v = SymbolDependencies() - - ret = v.visit(tree, **SymbolDependencies.default_args) - - assert ret[b] == [tree.children[0], tree.children[0].body[0]] - assert ret[a] == [tree.children[0]] - - -def test_symbol_dependencies_write_then_read_inner_loop(): - a = Symbol("a") - a2 = Symbol("a") - b = Symbol("b") - tree = c_for("i", 2, [c_for("j", 1, [Assign(a, 1)]), - c_for("j", 2, [Assign(b, a2)])]) - - v = SymbolDependencies() - - ret = v.visit(tree, **SymbolDependencies.default_args) - - assert ret[a2] == [tree.children[0]] - assert ret[a] == [tree.children[0], tree.children[0].body[0].children[0]] - assert ret[b] == [tree.children[0], tree.children[0].body[1].children[0]] - - -def test_find_loop_nests_single(): - tree = c_for("i", 2, []) - v = FindLoopNests() - - ret = v.visit(tree) - - assert len(ret) == 1 - assert ret[0] == [(tree.children[0], tree)] - - -def test_find_loop_nests_nested(): - tree = c_for("i", 10, [c_for("j", 4, []), - c_for("k", 6, [c_for("l", 2, [])])]) - - v = FindLoopNests() - - ret = v.visit(tree) - - assert len(ret) == 2 - - iloop = tree.children[0] - jloop = iloop.children[0].children[0].children[0] - kloop = iloop.children[0].children[1].children[0] - lloop = kloop.children[0].children[0] - - assert ret[0][0][0] == iloop - assert ret[0][1][0] == jloop - assert len(ret[0]) == 2 - assert ret[1][0][0] == iloop - assert ret[1][1][0] == kloop - assert ret[1][2][0] == lloop - assert len(ret[1]) == 3 - - -def test_find_coffee_expressions_empty(): - tree = c_for("i", 10, []) - v = FindCoffeeExpressions() - - ret = v.visit(tree) - - assert len(ret) == 0 - - -def test_find_coffee_expressions_single(): - a = Symbol("a") - b = Symbol("b") - assign = Assign(a, b, pragma="#pragma coffee expression") - tree = c_for("i", 10, [assign]) - - v = FindCoffeeExpressions() - - ret = v.visit(tree) - - assert len(ret) == 1 - - val = ret[assign] - - assert len(val) == 2 - - assert val[1] == [(tree.children[0], tree)] - assert val[0] == tree.children[0].children[0] - - -def test_find_coffee_expressions_nested(): - a = Symbol("a") - b = Symbol("b") - assign1 = Assign(a, b, pragma="#pragma coffee expression") - - c = Symbol("c") - d = Symbol("d", rank=("i", )) - - assign2 = Assign(d, c, pragma="#pragma coffee expression") - tree = c_for("i", 10, [c_for("j", - 2, [assign2]), - assign1]) - - v = FindCoffeeExpressions() - ret = v.visit(tree) - - assert len(ret) == 2 - - val1 = ret[assign1] - val2 = ret[assign2] - - assert val1[0] == tree.children[0].children[0] - assert val1[1] == [(tree.children[0], tree)] - - assert val2[0] == tree.children[0].body[0].children[0].children[0] - assert val2[1] == [(tree.children[0], tree), - (tree.children[0].body[0].children[0], - tree.children[0].body[0])] - - -def test_symbol_modes_simple(): - a = Symbol("a") - b = Symbol("b") - tree = Assign(a, b) - - v = SymbolModes() - - ret = v.visit(tree) - - assert len(ret) == 2 - - assert ret[a] == (WRITE, tree.__class__) - assert ret[b] == (READ, tree.__class__) - - -def test_symbol_modes_nested(): - a = Symbol("a") - b = Symbol("b") - - assign = Assign(a, b) - - c = Symbol("c") - d = Symbol("a") - - assign2 = Assign(c, d) - - tree = c_for("i", 10, [assign, - c_for("j", 10, [c_for("k", 10, [assign2])])]) - - v = SymbolModes() - - ret = v.visit(tree) - - assert ret[a] == (WRITE, Assign) - assert ret[b] == (READ, Assign) - assert ret[c] == (WRITE, Assign) - assert ret[d] == (READ, Assign) - - -if __name__ == "__main__": - import os - pytest.main(os.path.abspath(__file__))