diff --git a/backend/FastChat/.gitignore b/backend/FastChat/.gitignore new file mode 100644 index 0000000..94b6e61 --- /dev/null +++ b/backend/FastChat/.gitignore @@ -0,0 +1,31 @@ +# Python +__pycache__ +*.pyc +*.egg-info +dist +.venv + +# Log +*.log +*.log.* +*.json +!playground/deepspeed_config_s2.json +!playground/deepspeed_config_s3.json + +# Editor +.idea +*.swp + +# Other +.DS_Store +wandb +output +checkpoints_flant5_3b + +# Data +*.pkl +*.csv +tests/state_of_the_union.txt + +# Build +build diff --git a/backend/FastChat/.pylintrc b/backend/FastChat/.pylintrc new file mode 100644 index 0000000..864033f --- /dev/null +++ b/backend/FastChat/.pylintrc @@ -0,0 +1,449 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party,ray_patches,providers + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". 
If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat-in-sequence, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + logging-format-interpolation, # FIXME(sky): make pass. + logging-fstring-interpolation, # FIXME(sky): make pass. + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-object-inheritance, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". This option is deprecated +# and it will be removed in Pylint 2.0. +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. 
See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. 
Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here we use 4 spaces. +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. 
internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException + +####### + +# https://github.com/edaniszewski/pylint-quotes#configuration +string-quote=single +triple-quote=double +docstring-quote=double diff --git a/backend/FastChat/Dockerfile b/backend/FastChat/Dockerfile new file mode 100644 index 0000000..18145a2 --- /dev/null +++ b/backend/FastChat/Dockerfile @@ -0,0 +1,15 @@ +FROM nvidia/cuda:12.2.0-runtime-ubuntu20.04 +ENV DEBIAN_FRONTEND noninteractive + +RUN apt-get update -y && apt-get install -y python3.9 python3.9-distutils curl git + +# Set working directory and copy 'fastchat' folder +WORKDIR /app +COPY fastchat /app/fastchat + +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py +RUN python3.9 get-pip.py +RUN pip3 install fschat[model_worker,webui] pydantic==1.10.13 +RUN pip3 install huggingface_hub[cli] +RUN pip3 install bitsandbytes +RUN pip3 install scipy \ No newline at end of file diff --git a/backend/FastChat/LICENSE b/backend/FastChat/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/backend/FastChat/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/backend/FastChat/README.md b/backend/FastChat/README.md
new file mode 100644
index 0000000..9f9c62e
--- /dev/null
+++ b/backend/FastChat/README.md
@@ -0,0 +1,353 @@
+# FastChat
+| [**Demo**](https://chat.lmsys.org/) | [**Discord**](https://discord.gg/HSWAKCrnFx) | [**X**](https://x.com/lmsysorg) |
+
+FastChat is an open platform for training, serving, and evaluating large language model-based chatbots.
+- FastChat powers Chatbot Arena (https://chat.lmsys.org/), serving over 5 million chat requests for 30+ LLMs.
+- Arena has collected over 100K human votes from side-by-side LLM battles to compile an online [LLM Elo leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard).
+
+FastChat's core features include:
+- The training and evaluation code for state-of-the-art models (e.g., Vicuna, MT-Bench).
+- A distributed multi-model serving system with web UI and OpenAI-compatible RESTful APIs.
+
+## News
+- [2023/09] 🔥 We released **LMSYS-Chat-1M**, a large-scale real-world LLM conversation dataset. Read the [report](https://arxiv.org/abs/2309.11998).
+- [2023/08] We released **Vicuna v1.5** based on Llama 2 with 4K and 16K context lengths. Download [weights](#vicuna-weights).
+- [2023/07] We released **Chatbot Arena Conversations**, a dataset containing 33K conversations with human preferences. Download it [here](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations).
+
+<details>
+<summary>More</summary>
+
+- [2023/08] We released **LongChat v1.5** based on Llama 2 with 32K context lengths. Download [weights](#longchat).
+- [2023/06] We introduced **MT-bench**, a challenging multi-turn question set for evaluating chatbots. Check out the blog [post](https://lmsys.org/blog/2023-06-22-leaderboard/).
+- [2023/06] We introduced **LongChat**, our long-context chatbots and evaluation tools. Check out the blog [post](https://lmsys.org/blog/2023-06-29-longchat/).
+- [2023/05] We introduced **Chatbot Arena** for battles among LLMs. Check out the blog [post](https://lmsys.org/blog/2023-05-03-arena).
+- [2023/03] We released **Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality**. Check out the blog [post](https://vicuna.lmsys.org).
+
+</details>
+ + + +## Contents +- [Install](#install) +- [Model Weights](#model-weights) +- [Inference with Command Line Interface](#inference-with-command-line-interface) +- [Serving with Web GUI](#serving-with-web-gui) +- [API](#api) +- [Evaluation](#evaluation) +- [Fine-tuning](#fine-tuning) +- [Citation](#citation) + +## Install + +### Method 1: With pip + +```bash +pip3 install "fschat[model_worker,webui]" +``` + +### Method 2: From source + +1. Clone this repository and navigate to the FastChat folder +```bash +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +``` + +If you are running on Mac: +```bash +brew install rust cmake +``` + +2. Install Package +```bash +pip3 install --upgrade pip # enable PEP 660 support +pip3 install -e ".[model_worker,webui]" +``` + +## Model Weights +### Vicuna Weights +[Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) is based on Llama 2 and should be used under Llama's [model license](https://github.com/facebookresearch/llama/blob/main/LICENSE). + +You can use the commands below to start chatting. It will automatically download the weights from Hugging Face repos. +See more command options and how to handle out-of-memory in the "Inference with Command Line Interface" section below. + +**NOTE: `transformers>=4.31` is required for 16K versions.** + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 7B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | +| 7B-16k | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5-16k` | [lmsys/vicuna-7b-v1.5-16k](https://huggingface.co/lmsys/vicuna-7b-v1.5-16k) | +| 13B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-13b-v1.5` | [lmsys/vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5) | +| 13B-16k | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-13b-v1.5-16k` | [lmsys/vicuna-13b-v1.5-16k](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k) | +| 33B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-33b-v1.3` | [lmsys/vicuna-33b-v1.3](https://huggingface.co/lmsys/vicuna-33b-v1.3) | + +**Old weights**: see [docs/vicuna_weights_version.md](docs/vicuna_weights_version.md) for all versions of weights and their differences. + +### LongChat +We release [LongChat](https://lmsys.org/blog/2023-06-29-longchat/) models under Llama's [model license](https://github.com/facebookresearch/llama/blob/main/LICENSE). + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 7B | `python3 -m fastchat.serve.cli --model-path lmsys/longchat-7b-32k-v1.5` | [lmsys/longchat-7b-32k](https://huggingface.co/lmsys/longchat-7b-32k-v1.5) | + +### FastChat-T5 +You can use the commands below to chat with FastChat-T5. It will automatically download the weights from Hugging Face repos. + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 3B | `python3 -m fastchat.serve.cli --model-path lmsys/fastchat-t5-3b-v1.0` | [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5-3b-v1.0) | + +## Inference with Command Line Interface + + + +(Experimental Feature: You can specify `--style rich` to enable rich text output and better text streaming quality for some non-ASCII content. This may not work properly on certain terminals.) + +#### Supported Models +FastChat supports a wide range of models, including +LLama 2, Vicuna, Alpaca, Baize, ChatGLM, Dolly, Falcon, FastChat-T5, GPT4ALL, Guanaco, MTP, OpenAssistant, OpenChat, RedPajama, StableLM, WizardLM, and more. 
+ +See a complete list of supported models and instructions to add a new model [here](docs/model_support.md). + +#### Single GPU +The command below requires around 14GB of GPU memory for Vicuna-7B and 28GB of GPU memory for Vicuna-13B. +See the ["Not Enough Memory" section](#not-enough-memory) below if you do not have enough memory. +`--model-path` can be a local folder or a Hugging Face repo name. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 +``` + +#### Multiple GPUs +You can use model parallelism to aggregate GPU memory from multiple GPUs on the same machine. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --num-gpus 2 +``` + +Tips: +Sometimes the "auto" device mapping strategy in huggingface/transformers does not perfectly balance the memory allocation across multiple GPUs. +You can use `--max-gpu-memory` to specify the maximum memory per GPU for storing model weights. +This allows it to allocate more memory for activations, so you can use longer context lengths or larger batch sizes. For example, + +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --num-gpus 2 --max-gpu-memory 8GiB +``` + +#### CPU Only +This runs on the CPU only and does not require GPU. It requires around 30GB of CPU memory for Vicuna-7B and around 60GB of CPU memory for Vicuna-13B. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device cpu +``` + +Use Intel AI Accelerator AVX512_BF16/AMX to accelerate CPU inference. +``` +CPU_ISA=amx python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device cpu +``` + +#### Metal Backend (Mac Computers with Apple Silicon or AMD GPUs) +Use `--device mps` to enable GPU acceleration on Mac computers (requires torch >= 2.0). +Use `--load-8bit` to turn on 8-bit compression. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device mps --load-8bit +``` +Vicuna-7B can run on a 32GB M1 Macbook with 1 - 2 words / second. + +#### Intel XPU (Intel Data Center and Arc A-Series GPUs) +Install the [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html). Set the OneAPI environment variables: +``` +source /opt/intel/oneapi/setvars.sh +``` + +Use `--device xpu` to enable XPU/GPU acceleration. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device xpu +``` +Vicuna-7B can run on an Intel Arc A770 16GB. + +#### Ascend NPU (Huawei AI Processor) +Install the [Ascend PyTorch Adapter](https://github.com/Ascend/pytorch). Set the CANN environment variables: +``` +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +Use `--device npu` to enable NPU acceleration. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device npu +``` +Vicuna-7B/13B can run on an Ascend 910B NPU 60GB. + +#### Not Enough Memory +If you do not have enough memory, you can enable 8-bit compression by adding `--load-8bit` to commands above. +This can reduce memory usage by around half with slightly degraded model quality. +It is compatible with the CPU, GPU, and Metal backend. + +Vicuna-13B with 8-bit compression can run on a single GPU with 16 GB of VRAM, like an Nvidia RTX 3090, RTX 4080, T4, V100 (16GB), or an AMD RX 6800 XT. + +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --load-8bit +``` + +In addition to that, you can add `--cpu-offloading` to commands above to offload weights that don't fit on your GPU onto the CPU memory. 
+This requires 8-bit compression to be enabled and the bitsandbytes package to be installed, which is only available on linux operating systems. + +#### More Platforms and Quantization +- For AMD GPU users, please install ROCm and [the ROCm version of PyTorch](https://pytorch.org/get-started/locally/) before you install FastChat. See also this [post](https://github.com/lm-sys/FastChat/issues/104#issuecomment-1613791563). +- FastChat supports ExLlama V2. See [docs/exllama_v2.md](/docs/exllama_v2.md). +- FastChat supports GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). See [docs/gptq.md](/docs/gptq.md). +- FastChat supports AWQ 4bit inference with [mit-han-lab/llm-awq](https://github.com/mit-han-lab/llm-awq). See [docs/awq.md](/docs/awq.md). +- [MLC LLM](https://mlc.ai/mlc-llm/), backed by [TVM Unity](https://github.com/apache/tvm/tree/unity) compiler, deploys Vicuna natively on phones, consumer-class GPUs and web browsers via Vulkan, Metal, CUDA and WebGPU. + +## Serving with Web GUI + + + +To serve using the web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the webserver and model workers. You can learn more about the architecture [here](docs/server_arch.md). + +Here are the commands to follow in your terminal: + +#### Launch the controller +```bash +python3 -m fastchat.serve.controller +``` + +This controller manages the distributed workers. + +#### Launch the model worker(s) +```bash +python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 +``` +Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself to the controller . + +To ensure that your model worker is connected to your controller properly, send a test message using the following command: +```bash +python3 -m fastchat.serve.test_message --model-name vicuna-7b-v1.5 +``` +You will see a short output. + +#### Launch the Gradio web server +```bash +python3 -m fastchat.serve.gradio_web_server +``` + +This is the user interface that users will interact with. + +By following these steps, you will be able to serve your models using the web UI. You can open your browser and chat with a model now. +If the models do not show up, try to reboot the gradio web server. + +#### (Optional): Advanced Features, Scalability +- You can register multiple model workers to a single controller, which can be used for serving a single model with higher throughput or serving multiple models at the same time. When doing so, please allocate different GPUs and ports for different model workers. +``` +# worker 0 +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000 +# worker 1 +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001 +``` +- You can also launch a multi-tab gradio server, which includes the Chatbot Arena tabs. +```bash +python3 -m fastchat.serve.gradio_web_server_multi +``` +- The default model worker based on huggingface/transformers has great compatibility but can be slow. If you want high-throughput batched serving, you can try [vLLM integration](docs/vllm_integration.md). 
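If you prefer a single entry point, the serving commands above can be wrapped in a small script. The sketch below is illustrative only and not part of the upstream documentation; it assumes a single GPU, the default controller and worker ports, and that the hard-coded `sleep` delays are long enough for your model to finish loading.

```bash
#!/usr/bin/env bash
# Illustrative launch script: controller + one model worker + web UI.
# Assumes a single GPU and the default ports; the sleep times are rough
# guesses and should be adjusted for your hardware.

python3 -m fastchat.serve.controller &
sleep 10   # give the controller a moment to start

python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 &
sleep 120  # wait until the worker prints "Uvicorn running on ..."

# Optional sanity check that the worker registered with the controller.
python3 -m fastchat.serve.test_message --model-name vicuna-7b-v1.5

# Foreground process: the web UI that users interact with.
python3 -m fastchat.serve.gradio_web_server
```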
+
+#### (Optional): Advanced Features, Third Party UI
+- If you want to host FastChat behind your own UI or a third-party UI, launch the OpenAI-compatible server, expose it with a hosting service like ngrok, and enter the credentials appropriately. Example UIs:
+  - https://github.com/WongSaang/chatgpt-ui
+  - https://github.com/mckaywrigley/chatbot-ui
+- Note that some third-party UIs only offer the standard model names (`gpt-3.5-turbo`, `gpt-4`, etc.), so you will have to add your own custom model name inside the code. [Here is an example of modifying a UI to accept any custom model name](https://github.com/ztjhz/BetterChatGPT/pull/461).
+
+
+## API
+### OpenAI-Compatible RESTful APIs & SDK
+FastChat provides OpenAI-compatible APIs for its supported models, so you can use FastChat as a local drop-in replacement for OpenAI APIs.
+The FastChat server is compatible with both the [openai-python](https://github.com/openai/openai-python) library and cURL commands.
+See [docs/openai_api.md](docs/openai_api.md).
+
+### Hugging Face Generation APIs
+See [fastchat/serve/huggingface_api.py](fastchat/serve/huggingface_api.py).
+
+### LangChain Integration
+See [docs/langchain_integration.md](docs/langchain_integration.md).
+
+## Evaluation
+We use MT-bench, a set of challenging multi-turn open-ended questions, to evaluate models.
+To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as judges and assess the quality of the models' responses.
+See instructions for running MT-bench at [fastchat/llm_judge](fastchat/llm_judge).
+
+MT-bench is the new recommended way to benchmark your models. If you are still looking for the old 80 questions used in the Vicuna blog post, please go to [vicuna-blog-eval](https://github.com/lm-sys/vicuna-blog-eval).
+
+## Fine-tuning
+### Data
+
+Vicuna is created by fine-tuning a Llama base model using approximately 125K user-shared conversations gathered from ShareGPT.com with public APIs. To ensure data quality, we convert the HTML back to markdown and filter out some inappropriate or low-quality samples. Additionally, we divide lengthy conversations into smaller segments that fit the model's maximum context length. For detailed instructions on cleaning the ShareGPT data, see [here](docs/commands/data_cleaning.md).
+
+We will not release the ShareGPT dataset. If you would like to try the fine-tuning code, you can run it with some dummy conversations in [dummy_conversation.json](data/dummy_conversation.json). You can follow the same format and plug in your own data.
+
+### Code and Hyperparameters
+Our code is based on [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) with additional support for multi-turn conversations.
+We use hyperparameters similar to those of Stanford Alpaca.
+
+| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
+| --- | ---: | ---: | ---: | ---: | ---: |
+| Vicuna-13B | 128 | 2e-5 | 3 | 2048 | 0 |
+
+### Fine-tuning Vicuna-7B with Local GPUs
+
+- Install dependencies
+```bash
+pip3 install -e ".[train]"
+```
+
+- You can use the following command to train Vicuna-7B with 4 x A100 (40GB). Update `--model_name_or_path` with the actual path to the Llama weights and `--data_path` with the actual path to the data.
+```bash +torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --data_path data/dummy_conversation.json \ + --bf16 True \ + --output_dir output_vicuna \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 10 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True +``` + +Tips: +- If you are using V100 which is not supported by FlashAttention, you can use the [memory-efficient attention](https://arxiv.org/abs/2112.05682) implemented in [xFormers](https://github.com/facebookresearch/xformers). Install xformers and replace `fastchat/train/train_mem.py` above with [fastchat/train/train_xformers.py](fastchat/train/train_xformers.py). +- If you meet out-of-memory due to "FSDP Warning: When using FSDP, it is efficient and recommended... ", see solutions [here](https://github.com/huggingface/transformers/issues/24724#issuecomment-1645189539). +- If you meet out-of-memory during model saving, see solutions [here](https://github.com/pytorch/pytorch/issues/98823). + +### Other models, platforms and LoRA support +More instructions to train other models (e.g., FastChat-T5) and use LoRA are in [docs/training.md](docs/training.md). + +### Fine-tuning on Any Cloud with SkyPilot +[SkyPilot](https://github.com/skypilot-org/skypilot) is a framework built by UC Berkeley for easily and cost effectively running ML workloads on any cloud (AWS, GCP, Azure, Lambda, etc.). +Find SkyPilot documentation [here](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna) on using managed spot instances to train Vicuna and save on your cloud costs. + +## Citation +The code (training, serving, and evaluation) in this repository is mostly developed for or derived from the paper below. +Please cite it if you find the repository helpful. + +``` +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` + +We are also planning to add more of our research to this repository. diff --git a/backend/FastChat/fastchat/__init__.py b/backend/FastChat/fastchat/__init__.py new file mode 100644 index 0000000..c4feccf --- /dev/null +++ b/backend/FastChat/fastchat/__init__.py @@ -0,0 +1 @@ +__version__ = "0.2.33" diff --git a/backend/FastChat/fastchat/constants.py b/backend/FastChat/fastchat/constants.py new file mode 100644 index 0000000..53ed55c --- /dev/null +++ b/backend/FastChat/fastchat/constants.py @@ -0,0 +1,65 @@ +""" +Global constants. +""" + +from enum import IntEnum +import os + +REPO_PATH = os.path.dirname(os.path.dirname(__file__)) + +##### For the gradio web server +SERVER_ERROR_MSG = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES." +CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." +INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." +SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds." +# Maximum input length +INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) +# Maximum conversation turns +CONVERSATION_TURN_LIMIT = 50 +# Session expiration time +SESSION_EXPIRATION_TIME = 3600 +# The output dir of log files +LOGDIR = os.getenv("LOGDIR", ".") +# CPU Instruction Set Architecture +CPU_ISA = os.getenv("CPU_ISA") + + +##### For the controller and workers (could be overwritten through ENV variables.) +CONTROLLER_HEART_BEAT_EXPIRATION = int( + os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) +) +WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) +WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) +WORKER_API_EMBEDDING_BATCH_SIZE = int( + os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) +) + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 diff --git a/backend/FastChat/fastchat/conversation.py b/backend/FastChat/fastchat/conversation.py new file mode 100644 index 0000000..a0d7d7f --- /dev/null +++ b/backend/FastChat/fastchat/conversation.py @@ -0,0 +1,1363 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. +""" + +import dataclasses +from enum import auto, IntEnum +from typing import List, Any, Dict, Union, Tuple + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + # The names of two roles + roles: Tuple[str] = ("USER", "ASSISTANT") + # All messages. Each item is (role, message). 
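+    # For example: [["USER", "Hello!"], ["ASSISTANT", None]], where a trailing None
+    # is the placeholder that update_last_message() fills in after generation.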
+ messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ": " # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = "" if system_prompt == "" else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ": " + + message.replace("\r\n", "\n").replace("\n\n", "\n") + ) + ret += "\n\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == "chatglm2" else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = "" + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f"[Round {i//2 + round_add_n}]{self.sep}" + + if message: + ret += f"{role}:{message}{self.sep}" + else: + ret += f"{role}:" + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if system_prompt == "" else system_prompt + self.sep + "\n" + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + "\n" + else: + 
ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = "" + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + "\n" + " " + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += "" + if message: + ret += role + ":" + message + seps[i % 2] + "\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":\n" + message + seps[i % 2] + if i % 2 == 1: + ret += "\n\n" + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + "" + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ":\n" + message + self.sep + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = "" + if self.system_message: + ret += system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. 
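+
+        Illustrative usage, where `conv` is a Conversation instance and
+        `generate` stands in for your own inference call:
+
+            conv.append_message(conv.roles[1], None)
+            prompt = conv.get_prompt()
+            conv.update_last_message(generate(prompt))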
+ """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + assert name in conv_templates, f"{name} is not registered." + print(conv_templates[name].dict()) + return conv_templates[name].copy() + + +# An empty template for raw conversation. +register_conv_template( + Conversation( + name="raw", + system_message="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# A template with a one-shot conversation example +register_conv_template( + Conversation( + name="one_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "Got any creative ideas for a 10 year old’s birthday?", + ), + ( + "Assistant", + """Of course! Here are some creative ideas for a 10-year-old's birthday party: +1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. +2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. +3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. +4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. +5. Sports Olympics: Host a mini Olympics event with various sports and games. 
Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
+6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
+7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
+8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
+Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
+            ),
+        ),
+        offset=2,
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str=["###", "<|im_start|>", "<|im_end|>", "$$$", "thank you for your help"]
+    )
+)
+
+# A template similar to the "one_shot" template above, but with the example removed.
+register_conv_template(
+    Conversation(
+        name="zero_shot",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+        roles=("Human", "Assistant"),
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str=["###", "<|im_start|>", "<|im_end|>", "$$$", "thank you for your help"]
+    )
+)
+
+system_msg_trials = "You are a helpful, respectful and honest assistant. " + \
+    "Always answer as helpfully as possible, while being safe. " + \
+    "Use only the following context about clinical trials to answer the user's question. " + \
+    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " + \
+    "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + \
+    "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " + \
+    "If you don't know the answer to a question, please don't share false information or invent information that is not in the context. Analyze the question and answer it using only the information in the context that matches it."
+
+system_msg_meditron = "You are a helpful, respectful and honest assistant. " + \
+    "Always answer as helpfully as possible, while being safe. " + \
+    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " + \
+    "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + \
+    "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " + \
+    "If you don't know the answer to a question, please don't share false information."
+
+register_conv_template(
+    Conversation(
+        name="zero_shot_medical",
+        system_message=system_msg_meditron,
+        roles=("User", "Assistant"),
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str=["</s>", "###", "Assistant:", "User:"]
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="one_shot_medical",
+        system_message=system_msg_meditron,
+        roles=("User", "Assistant"),
+        messages=(
+            (
+                "User",
+                "What happens if listeria is left untreated?",
+            ),
+            (
+                "Assistant",
+                """If listeria infection, or listeriosis, is left untreated, it can lead to severe health complications, particularly in certain high-risk groups. Here's a general overview of the potential outcomes:
+
+1. Spread of the Infection: Untreated listeriosis can spread beyond the gut to other parts of the body, including the nervous system. This can lead to more severe conditions like meningitis (inflammation of the membranes surrounding the brain and spinal cord) and septicemia (a serious blood infection).
+2. Increased Risk for Certain Groups: Pregnant women, newborns, the elderly, and individuals with weakened immune systems are at a higher risk of severe complications. In pregnant women, listeriosis can lead to miscarriage, stillbirth, premature delivery, or life-threatening infection of the newborn.
+3. Neurological Effects: Listeriosis can cause severe neurological symptoms like headaches, stiff neck, confusion, loss of balance, and convulsions, especially when the infection spreads to the nervous system.
+4. Long-Term Health Impacts: For some, particularly those with pre-existing health conditions or weakened immune systems, the health impacts of listeriosis can be long-lasting and may not fully resolve even with treatment.
+5. Fatalities: In severe cases, particularly among high-risk groups, listeriosis can be fatal.
+
+It's important to note that early diagnosis and appropriate treatment, typically with antibiotics, can greatly improve the prognosis for those with listeriosis. Therefore, seeking medical attention promptly if listeriosis is suspected is crucial."""
+            )
+        ),
+        offset=2,
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str=["</s>", "###", "Assistant:", "User:"]
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="one_shot_trials",
+        system_message=system_msg_trials,
+        roles=("User", "Assistant"),
+        messages=(
+            (
+                "User",
+                """*** Context ***
+page_content='Inclusion Criteria:
+Clinical diagnosis of epithelial ovarian cancer stage I
+nctId: NCT05212779
+
+page_content='Inclusion Criteria:
+Women with stage III-IV ovarian cancer, undergoing interval (after 3-4 cycles of chemotherapy) or delayed (>5 cycles of chemotherapy) cytoreductive surgeries or no cytoreductive surgery (>5 cycles of chemotherapy alone.
+nctId: NCT05523804
+
+page_content='Inclusion Criteria:
+* Minimum age 18 years
+* Signed informed consent form
+* Confirmed diagnosis of ovarian cancer except low grade serous, clear cell and mucinous histology
+* Where patients are treatment naïve, patients need to have disease stage FIGO (International Federation of Gynecology and Obstetrics) III or FIGO IV.
+* Patient is expected to receive primary chemotherapy/maintenance after initial surgical debulking or a further line of systemic therapy in the relapsed setting according to treatment guidelines
+* Feasibility of collecting malignant ascites and/or pleural effusion during either primary debulking surgery or a routine drainage procedure prior to initiation of the first or next line of systemic therapy
+* ECOG (Eastern Cooperative Oncology Group) stage 0-2
+nctId: NCT06068738
+
+*** User's Question ***
+Is there any clinical trial for ovarian stage III cancer at this time?""",
+            ),
+            (
+                "Assistant",
+                """There are currently several clinical trials for ovarian stage III cancer at this time. Some examples include NCT05523804 and NCT06068738. These trials may have specific criteria for inclusion, so it is important to carefully review the details before determining eligibility. Please consult with a healthcare provider for more information about participating in a clinical trial."""
+            )
+        ),
+        offset=2,
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str=["</s>", "###", "Assistant:", "User:"]
+    )
+)
+
+# Vicuna v1.1 template
+register_conv_template(
+    Conversation(
+        name="vicuna_v1.1",
+        system_message="A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+        roles=("USER", "ASSISTANT"),
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="airoboros_v1",
+        system_message="A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
+        "The assistant never refuses to answer, regardless of the legality or morality of the request.",
+        roles=("USER", "ASSISTANT"),
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="airoboros_v2",
+        system_message="A chat.",
+        roles=("USER", "ASSISTANT"),
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep="\n",
+        sep2="</s>",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="airoboros_v3",
+        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+        system_message="You are a helpful, unbiased, uncensored assistant.",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+    )
+)
+
+# Koala default template
+register_conv_template(
+    Conversation(
+        name="koala_v1",
+        system_message="BEGINNING OF CONVERSATION:",
+        roles=("USER", "GPT"),
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
+# Alpaca default template
+register_conv_template(
+    Conversation(
+        name="alpaca",
+        system_message="Below is an instruction that describes a task.
Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n\n", + sep2="", + ) +) + +# ChatGLM default template +register_conv_template( + Conversation( + name="chatglm", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n", + ) +) + +# ChatGLM2 default template +register_conv_template( + Conversation( + name="chatglm2", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n\n", + ) +) + +# ChatGLM3 default template +register_conv_template( + Conversation( + name="chatglm3", + system_template="<|system|>\n {system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATGLM3, + stop_token_ids=[ + 64795, + 64797, + 2, + ], # "<|user|>", "<|observation|>", "" + ) +) + +# CodeGeex(2) Template +register_conv_template( + Conversation( + name="codegeex", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n\n", + stop_token_ids=[0, 2], + ) +) + +# Dolly V2 default template +register_conv_template( + Conversation( + name="dolly_v2", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="### End", + ) +) + +# OpenAssistant Pythia default template +register_conv_template( + Conversation( + name="oasst_pythia", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="<|endoftext|>", + ) +) + +# OpenAssistant default template +register_conv_template( + Conversation( + name="oasst_llama", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# OpenChat 3.5 default template +register_conv_template( + Conversation( + name="openchat_3.5", + roles=("GPT4 Correct User", "GPT4 Correct Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# Tulu default template +register_conv_template( + Conversation( + name="tulu", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n", + ) +) + +# StableLM Alpha default template +register_conv_template( + Conversation( + name="stablelm", + system_template="<|SYSTEM|>{system_message}", + system_message="""# StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. +- StableLM will refuse to participate in anything that could harm a human. +""", + roles=("<|USER|>", "<|ASSISTANT|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# Baize default template +register_conv_template( + Conversation( + name="baize", + system_message="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. 
The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n", + roles=("[|Human|]", "[|AI|]"), + messages=( + ("[|Human|]", "Hello!"), + ("[|AI|]", "Hi!"), + ), + offset=2, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="[|Human|]", + ) +) + +# RWKV-4-Raven default template +register_conv_template( + Conversation( + name="rwkv", + roles=("Bob", "Alice"), + messages=( + ("Bob", "hi"), + ( + "Alice", + "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", + ), + ), + offset=2, + sep_style=SeparatorStyle.RWKV, + sep="", + stop_str="\n\n", + ) +) + +# Buddy default template +register_conv_template( + Conversation( + name="openbuddy", + system_message="""Consider a conversation between User (a human) and Assistant (named Buddy). +Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy +Buddy cannot access the Internet. +Buddy can fluently speak the user's language (e.g. English, Chinese). +Buddy can generate poems, stories, code, essays, songs, parodies, and more. +Buddy possesses vast knowledge about the world, history, and culture. +Buddy's responses are always safe, creative, high-quality, human-like, and interesting. +Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. + +User: Hi. +Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# Phoenix default template +register_conv_template( + Conversation( + name="phoenix", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ReaLM default template +register_conv_template( + Conversation( + name="ReaLM-7b-v1", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ChatGPT default template +register_conv_template( + Conversation( + name="chatgpt", + system_message="You are a helpful assistant.", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + +# Claude default template +register_conv_template( + Conversation( + name="claude", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# MPT default template +register_conv_template( + Conversation( + name="mpt-7b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# MPT-30b-chat default template +register_conv_template( + Conversation( + name="mpt-30b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# Lemur-70b-chat default template +# reference: https://huggingface.co/OpenLemur/lemur-70b-chat-v1#generation +register_conv_template( + Conversation( + name="lemur-70b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""You are a helpful, respectful, and honest assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32002, 0], + ) +) + +# MPT-30b-instruct default template +# reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting +register_conv_template( + Conversation( + name="mpt-30b-instruct", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_token_ids=[50278, 0], + ) +) + +# Bard default template +# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 +# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 +register_conv_template( + Conversation( + name="bard", + roles=("0", "1"), + sep_style=None, + sep=None, + ) +) + +# BiLLa default template +register_conv_template( + Conversation( + name="billa", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="\n", + stop_str="Human:", + ) +) + +# RedPajama INCITE default template +register_conv_template( + Conversation( + name="redpajama-incite", + roles=("", ""), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# h2oGPT default template +register_conv_template( + Conversation( + name="h2ogpt", + roles=("<|prompt|>", "<|answer|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# Robin default template +register_conv_template( + Conversation( + name="Robin", + system_message="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("###Human", "###Assistant"), + sep_style=SeparatorStyle.ROBIN, + sep="\n", + stop_token_ids=[2, 396], + stop_str="###", + ) +) + +# Snoozy default template +# Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 +register_conv_template( + Conversation( + name="snoozy", + system_template="### Instruction:\n{system_message}", + system_message="The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", + roles=("### Prompt", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="###", + ) +) + +# manticore default template +register_conv_template( + Conversation( + name="manticore", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +# Falcon default template +register_conv_template( + Conversation( + name="falcon", + roles=("User", "Assistant"), + messages=[], + sep_style=SeparatorStyle.RWKV, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_token_ids=[ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + ], # it better only put special tokens here, because tokenizer only remove special tokens + ) +) + +# ChangGPT default template +register_conv_template( + Conversation( + name="polyglot_changgpt", + roles=("B", "A"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# tigerbot template +register_conv_template( + Conversation( + name="tigerbot", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ROBIN, + sep="\n\n", + stop_str="###", + ) +) + +# ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst +register_conv_template( + Conversation( + name="xgen", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_token_ids=[50256], + ) +) + +# Internlm-chat template +register_conv_template( + Conversation( + name="internlm-chat", + system_message="A chat between a curious <|User|> and an <|Bot|>. 
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", + roles=("<|User|>", "<|Bot|>"), + sep_style=SeparatorStyle.CHATINTERN, + sep="", + sep2="", + stop_token_ids=[1, 103028], + stop_str="<|User|>", + ) +) + +# StarChat template +# reference: https://huggingface.co/spaces/HuggingFaceH4/starchat-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="starchat", + system_template="\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="<|end|>", + stop_token_ids=[0, 49155], + stop_str="<|end|>", + ) +) + +# Baichuan-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555 + # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan-13B/issues/25 + Conversation( + name="baichuan-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Baichuan2-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773 + # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan2/issues/62 + Conversation( + name="baichuan2-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Mistral template +# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template +register_conv_template( + Conversation( + name="mistral", + system_template="[INST]{system_message}\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2="", + ) +) + +# llama2 template +# reference: https://huggingface.co/blog/codellama#conversational-instructions +# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_message="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +register_conv_template( + Conversation( + name="cutegpt", + roles=("问:", "答:\n"), + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# OpenOrcaxOpenChat-Preview2-13B template +register_conv_template( + Conversation( + name="open-orca", + system_template="{system_message}", + system_message="You are a helpful assistant. Please answer truthfully and write out your " + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. 
You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate.", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="<|end_of_turn|>\n", + stop_token_ids=[32000, 32001], # "<|end_of_turn|>" + stop_str="User", + ) +) + +# Open-Orca/Mistral-7B-OpenOrca template +# source: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca +# reference: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template +register_conv_template( + Conversation( + name="mistral-7b-openorca", + system_template="<|im_start|>system\n{system_message}", + system_message="You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + +# Qwen-chat default template +# source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 +register_conv_template( + Conversation( + name="qwen-7b-chat", + system_template="<|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 151643, + 151644, + 151645, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + stop_str="<|endoftext|>", + ) +) + + +# AquilaChat default template +# source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py +register_conv_template( + Conversation( + name="aquila-chat", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="###", + sep2="", + stop_str=["###", "", "[UNK]"], + ) +) +# AquilaChat2-34B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212 +register_conv_template( + Conversation( + name="aquila-legacy", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human: ", "### Assistant: "), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="", + stop_str=["", "[UNK]"], + ) +) +# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 +register_conv_template( + Conversation( + name="aquila", + system_message="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="###", + sep2="", + stop_str=["", "[UNK]"], + ) +) + +# AquilaChat2-7B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 +register_conv_template( + Conversation( + name="aquila-v1", + roles=("<|startofpiece|>", "<|endofpiece|>"), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ) +) + +# Llama2-Chinese default template +# source: https://huggingface.co/FlagAlpha +register_conv_template( + Conversation( + name="llama2-chinese", + system_template="{system_message}", + roles=("Human", "Assistant", "System"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# Vigogne Instruct default template +# source: https://github.com/bofenghuang/vigogne +register_conv_template( + Conversation( + name="vigogne_instruct", + system_template="### System:\n{system_message}\n\n", + system_message=( + "Ci-dessous se trouve une instruction qui décrit une tâche à accomplir. Rédigez une réponse qui répond de manière" + " précise à la demande." + ), + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="", + ) +) + +# Vigogne Chat default template +register_conv_template( + Conversation( + name="vigogne_chat_v2", + system_template="<|system|>: {system_message}", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." + ), + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="<|user|>", + ) +) + +register_conv_template( + Conversation( + name="vigogne_chat_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." + ), + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Falcon 180B chat template +# source: https://huggingface.co/spaces/tiiuae/falcon-180b-demo/blob/d1590ee7fae9b6ce331ba7808e61a29dcce9239f/app.py#L28-L37 +register_conv_template( + Conversation( + name="falcon-chat", + roles=("User", "Falcon"), + system_template="System: {system_message}", + messages=[], + sep_style=SeparatorStyle.FALCON_CHAT, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + ) +) + +# Phind template +# source: https://huggingface.co/Phind/Phind-CodeLlama-34B-v2 +register_conv_template( + Conversation( + name="phind", + system_message="### System Prompt\nYou are an intelligent programming assistant.", + roles=("### User Message", "### Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# Metharme formatting for Pygmalion models +# source: https://huggingface.co/PygmalionAI/pygmalion-2-13b +register_conv_template( + Conversation( + name="metharme", + system_template="<|system|>{system_message}", + system_message="""Enter RP mode. You shall reply to the user while staying + in character. 
Your responses must be detailed, creative, immersive, and drive the scenario
+    forward.""",
+        roles=("<|user|>", "<|model|>"),
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        sep="",
+        stop_str="<|user|>",
+    )
+)
+
+# Zephyr template
+# reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py
+register_conv_template(
+    Conversation(
+        name="zephyr",
+        system_template="<|system|>\n{system_message}",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="</s>",
+        stop_token_ids=[2],
+        stop_str="</s>",
+    )
+)
+
+
+if __name__ == "__main__":
+    from fastchat.conversation import get_conv_template
+
+    # print("-- Vicuna template --")
+    # conv = get_conv_template("vicuna_v1.1")
+    # conv.append_message(conv.roles[0], "Hello!")
+    # conv.append_message(conv.roles[1], "Hi!")
+    # conv.append_message(conv.roles[0], "How are you?")
+    # conv.append_message(conv.roles[1], None)
+    # print(conv.get_prompt())
+
+    # print("\n")
+
+    # print("-- Llama-2 template --")
+    # conv = get_conv_template("llama-2")
+    # conv.set_system_message("You are a helpful, respectful and honest assistant.")
+    # conv.append_message(conv.roles[0], "Hello!")
+    # conv.append_message(conv.roles[1], "Hi!")
+    # conv.append_message(conv.roles[0], "How are you?")
+    # conv.append_message(conv.roles[1], None)
+    # print(conv.get_prompt())
+
+    # print("\n")
+
+    # print("-- ChatGPT template --")
+    # conv = get_conv_template("chatgpt")
+    # conv.append_message(conv.roles[0], "Hello!")
+    # conv.append_message(conv.roles[1], "Hi!")
+    # conv.append_message(conv.roles[0], "How are you?")
+    # conv.append_message(conv.roles[1], None)
+    # print(conv.to_openai_api_messages())
+
+    # print("\n")
+
+    # print("-- Claude template --")
+    # conv = get_conv_template("claude")
+    # conv.append_message(conv.roles[0], "Hello!")
+    # conv.append_message(conv.roles[1], "Hi!")
+    # conv.append_message(conv.roles[0], "How are you?")
+    # conv.append_message(conv.roles[1], None)
+    # print(conv.get_prompt())
+
+    print("\n")
+
+    print("-- Meditron template --")
+    conv = get_conv_template("zero_shot_medical")  # Meditron-style system prompt
+    conv.append_message(conv.roles[0], "Hello!")
+    conv.append_message(conv.roles[1], "Hi!")
+    conv.append_message(conv.roles[0], "How are you?")
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
+
+    print("-- Meditron One-shot template --")
+    conv = get_conv_template("one_shot_medical")
+    conv.append_message(conv.roles[0], "Hello!")
+    conv.append_message(conv.roles[1], "Hi!")
+    conv.append_message(conv.roles[0], "How are you?")
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
+
+    print("-- Chatglm3 template --")
+    conv = get_conv_template("chatglm3")
+    conv.append_message(conv.roles[0], "Hello!")
+    conv.append_message(conv.roles[1], "Hi!")
+    conv.append_message(conv.roles[0], "How are you?")
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
diff --git a/backend/FastChat/fastchat/model/__init__.py b/backend/FastChat/fastchat/model/__init__.py
new file mode 100644
index 0000000..29767dc
--- /dev/null
+++ b/backend/FastChat/fastchat/model/__init__.py
@@ -0,0 +1,5 @@
+from fastchat.model.model_adapter import (
+    load_model,
+    get_conversation_template,
+    add_model_args,
+)
diff --git a/backend/FastChat/fastchat/model/apply_delta.py b/backend/FastChat/fastchat/model/apply_delta.py
new file mode 100644
index 0000000..ba1c06d
--- /dev/null
+++ b/backend/FastChat/fastchat/model/apply_delta.py
@@ -0,0 +1,165 @@
+"""
+Apply the delta weights on top of a base model.
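+A delta checkpoint stores, for each tensor, the difference between the released
+model and its base weights; adding it back onto the base reconstructs the
+released model (fastchat.model.make_delta produces such deltas by subtraction).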
+ +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 +""" +import argparse +import gc +import glob +import json +import os +import shutil +import tempfile + +from huggingface_hub import snapshot_download +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +GB = 1 << 30 + + +def split_files(model_path, tmp_path, split_size): + if not os.path.exists(model_path): + model_path = snapshot_download(repo_id=model_path) + if not os.path.exists(tmp_path): + os.makedirs(tmp_path) + + file_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(file_pattern) + + part = 0 + try: + for file_path in tqdm(files): + state_dict = torch.load(file_path) + new_state_dict = {} + + current_size = 0 + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + if current_size + param_size > split_size: + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + current_size = 0 + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + + new_state_dict[name] = param + current_size += param_size + + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + except Exception as e: + print(f"An error occurred during split_files: {e}") + shutil.rmtree(tmp_path) + raise + + +def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta_config = AutoConfig.from_pretrained(delta_path) + + if os.path.exists(target_model_path): + shutil.rmtree(target_model_path) + os.makedirs(target_model_path) + + split_size = 4 * GB + + with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: + print(f"Split files for the base model to {tmp_base_path}") + split_files(base_model_path, tmp_base_path, split_size) + print(f"Split files for the delta weights to {tmp_delta_path}") + split_files(delta_path, tmp_delta_path, split_size) + + base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") + base_files = glob.glob(base_pattern) + delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") + delta_files = glob.glob(delta_pattern) + delta_state_dict = torch.load(delta_files[0]) + + print("Applying the delta") + weight_map = {} + total_size = 0 + + for i, base_file in tqdm(enumerate(base_files)): + state_dict = torch.load(base_file) + file_name = f"pytorch_model-{i}.bin" + for name, param in state_dict.items(): + if name not in delta_state_dict: + for delta_file in delta_files: + delta_state_dict = torch.load(delta_file) + gc.collect() + if name in delta_state_dict: + break + + state_dict[name] += delta_state_dict[name] + weight_map[name] = file_name + total_size += param.numel() * param.element_size() + gc.collect() + torch.save(state_dict, os.path.join(target_model_path, file_name)) + + with open( + os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" + ) as f: + json.dump( + {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f + ) + + print(f"Saving the target model to {target_model_path}") + delta_tokenizer.save_pretrained(target_model_path) + 
delta_config.save_pretrained(target_model_path) + + +def apply_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the delta weights from {delta_path}") + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta = AutoModelForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print("Applying the delta") + for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): + assert name in delta.state_dict() + param.data += delta.state_dict()[name] + + print(f"Saving the target model to {target_model_path}") + base.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument( + "--low-cpu-mem", + action="store_true", + help="Lower the cpu memory usage. This will split large files and use " + "disk as swap to reduce the memory usage below 10GB.", + ) + args = parser.parse_args() + + if args.low_cpu_mem: + apply_delta_low_cpu_mem( + args.base_model_path, args.target_model_path, args.delta_path + ) + else: + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/backend/FastChat/fastchat/model/apply_lora.py b/backend/FastChat/fastchat/model/apply_lora.py new file mode 100644 index 0000000..01263dc --- /dev/null +++ b/backend/FastChat/fastchat/model/apply_lora.py @@ -0,0 +1,48 @@ +""" +Apply the LoRA weights on top of a base model. 
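+The adapter is merged into the base weights via peft's merge_and_unload, so the
+resulting checkpoint can later be loaded like a plain Hugging Face model, with
+no peft dependency at inference time.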
+ +Usage: +python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B + +Dependency: +pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b +""" +import argparse + +import torch +from peft import PeftModel +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def apply_lora(base_model_path, target_model_path, lora_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) + + print(f"Loading the LoRA adapter from {lora_path}") + + lora_model = PeftModel.from_pretrained( + base, + lora_path, + # torch_dtype=torch.float16 + ) + + print("Applying the LoRA") + model = lora_model.merge_and_unload() + + print(f"Saving the target model to {target_model_path}") + model.save_pretrained(target_model_path) + base_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--lora-path", type=str, required=True) + + args = parser.parse_args() + + apply_lora(args.base_model_path, args.target_model_path, args.lora_path) diff --git a/backend/FastChat/fastchat/model/compression.py b/backend/FastChat/fastchat/model/compression.py new file mode 100644 index 0000000..7329cfe --- /dev/null +++ b/backend/FastChat/fastchat/model/compression.py @@ -0,0 +1,312 @@ +import dataclasses +import gc +import glob +import os + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from huggingface_hub import snapshot_download +import torch +from torch import Tensor +from torch.nn import functional as F +import torch.nn as nn +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + AutoModel, + AutoModelForSeq2SeqLM, +) + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight=None, bias=None, device=None): + super().__init__() + if weight is None: + self.weight = None + elif isinstance(weight, Tensor): + self.weight = compress(weight.data.to(device), default_compression_config) + else: + self.weight = weight + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + if self.bias is None: + return F.linear(input.to(weight.dtype), weight) + return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype)) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr( + module, + attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device), + ) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def get_compressed_list(module, prefix=""): + compressed_list = [] + for attr_str in dir(module): + 
target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + compressed_list.append(full_name) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + for each in get_compressed_list(child, child_prefix): + compressed_list.append(each) + return compressed_list + + +def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + setattr( + module, + attr_str, + CLinear( + compressed_state_dict[full_name], target_attr.bias, target_device + ), + ) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + apply_compressed_weight( + child, compressed_state_dict, target_device, child_prefix + ) + + +def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"): + # partially load model + # `use_fast=True`` is not supported for some models. + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=use_fast, revision=revision, trust_remote_code=True + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True + ) + with init_empty_weights(): + # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel + config = AutoConfig.from_pretrained( + model_path, + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + trust_remote_code=True, + revision=revision, + ) + # some models are loaded by AutoModel but not AutoModelForCausalLM, + # such as chatglm, chatglm2 + try: + # google/flan-* models are based on an AutoModelForSeq2SeqLM. + if "T5Config" in str(type(config)): + model = AutoModelForSeq2SeqLM.from_config( + config, trust_remote_code=True + ) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + except NameError: + model = AutoModel.from_config(config, trust_remote_code=True) + linear_weights = get_compressed_list(model) + if os.path.exists(model_path): + # `model_path` is a local folder + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + else: + # `model_path` is a cached Hugging Face repo + # We don't necessarily need to download the model' repo again if there is a cache. + # So check the default huggingface cache first. + model_path_temp = os.path.join( + os.path.expanduser("~"), + ".cache/huggingface/hub", + "models--" + model_path.replace("/", "--"), + "snapshots/", + ) + downloaded = False + if os.path.exists(model_path_temp): + temp_last_dir = os.listdir(model_path_temp)[-1] + model_path_temp = os.path.join(model_path_temp, temp_last_dir) + base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin") + files = glob.glob(base_pattern) + if len(files) > 0: + downloaded = True + + if downloaded: + model_path = model_path_temp + else: + model_path = snapshot_download(model_path, revision=revision) + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + + files = glob.glob(base_pattern) + use_safetensors = False + if len(files) == 0: + base_pattern = os.path.join(model_path, "*.safetensors") + files = glob.glob(base_pattern) + use_safetensors = True + if len(files) == 0: + raise ValueError( + f"Cannot find any model weight files. 
" + f"Please check your (cached) weight path: {model_path}" + ) + + compressed_state_dict = {} + if use_safetensors: + from safetensors.torch import load_file + for filename in tqdm(files): + if use_safetensors: + tmp_state_dict = load_file(filename) + else: + tmp_state_dict = torch.load( + filename, map_location=lambda storage, loc: storage + ) + for name in tmp_state_dict: + if name in linear_weights: + tensor = tmp_state_dict[name].to(device, dtype=torch_dtype) + compressed_state_dict[name] = compress( + tensor, default_compression_config + ) + else: + compressed_state_dict[name] = tmp_state_dict[name].to( + device, dtype=torch_dtype + ) + tmp_state_dict[name] = None + tensor = None + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + for name in model.state_dict(): + if name not in linear_weights: + set_module_tensor_to_device( + model, name, device, value=compressed_state_dict[name] + ) + apply_compressed_weight(model, compressed_state_dict, device) + + if torch_dtype == torch.float16: + model.half() + model.to(device) + model.eval() + + return model, tokenizer + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2**num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/backend/FastChat/fastchat/model/convert_fp16.py 
b/backend/FastChat/fastchat/model/convert_fp16.py new file mode 100644 index 0000000..efc40aa --- /dev/null +++ b/backend/FastChat/fastchat/model/convert_fp16.py @@ -0,0 +1,26 @@ +""" +Usage: +python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder +""" +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + + +def convert_fp16(in_checkpoint, out_checkpoint): + tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.save_pretrained(out_checkpoint) + tokenizer.save_pretrained(out_checkpoint) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-checkpoint", type=str, help="Path to the model") + parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") + args = parser.parse_args() + + convert_fp16(args.in_checkpoint, args.out_checkpoint) diff --git a/backend/FastChat/fastchat/model/llama_condense_monkey_patch.py b/backend/FastChat/fastchat/model/llama_condense_monkey_patch.py new file mode 100644 index 0000000..cb45a8b --- /dev/null +++ b/backend/FastChat/fastchat/model/llama_condense_monkey_patch.py @@ -0,0 +1,71 @@ +# Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py + +from functools import partial + +import torch +import transformers +import transformers.models.llama.modeling_llama + + +class CondenseRotaryEmbedding(torch.nn.Module): + def __init__( + self, dim, ratio, max_position_embeddings=2048, base=10000, device=None + ): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.ratio = ratio + max_position_embeddings *= ratio + self.max_seq_len_cached = max_position_embeddings + # print(f"Monkey Patching condense ratio {ratio}") + t = ( + torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + / ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
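+        # Note: __init__ built the cos/sin caches for ratio * max_position_embeddings
+        # positions with every position index divided by `ratio`, i.e. the rotary
+        # positions are interpolated to stretch the effective context window.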
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = ( + torch.arange( + self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype + ) + / self.ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def replace_llama_with_condense(ratio): + transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( + CondenseRotaryEmbedding, ratio=ratio + ) diff --git a/backend/FastChat/fastchat/model/make_delta.py b/backend/FastChat/fastchat/model/make_delta.py new file mode 100644 index 0000000..480ba8f --- /dev/null +++ b/backend/FastChat/fastchat/model/make_delta.py @@ -0,0 +1,48 @@ +""" +Make the delta weights by subtracting base weights. + +Usage: +python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the target model from {target_model_path}") + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) + + print("Calculating the delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + assert name in base.state_dict() + param.data -= base.state_dict()[name] + + print(f"Saving the delta to {delta_path}") + if args.hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/backend/FastChat/fastchat/model/model_adapter.py b/backend/FastChat/fastchat/model/model_adapter.py new file mode 100644 index 0000000..bbffcf3 --- /dev/null +++ b/backend/FastChat/fastchat/model/model_adapter.py @@ -0,0 +1,1928 @@ +"""Model adapter registration.""" + +import math +import os +import re +import sys +from typing import Dict, List, Optional +import warnings + +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache as cache + +import accelerate +import psutil +import torch +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + 
AutoModelForSeq2SeqLM, + AutoTokenizer, + LlamaTokenizer, + LlamaForCausalLM, + T5Tokenizer, + BitsAndBytesConfig +) + +from fastchat.constants import CPU_ISA +from fastchat.conversation import Conversation, get_conv_template +from fastchat.model.compression import load_compress_model +from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_exllama import generate_stream_exllama +from fastchat.model.model_xfastertransformer import generate_stream_xft +from fastchat.model.monkey_patch_non_inplace import ( + replace_llama_attn_with_non_inplace_operations, +) +from fastchat.modules.awq import AWQConfig, load_awq_quantized +from fastchat.modules.exllama import ExllamaConfig, load_exllama_model +from fastchat.modules.xfastertransformer import load_xft_model, XftConfig +from fastchat.modules.gptq import GptqConfig, load_gptq_quantized +from fastchat.utils import get_gpu_memory + +# Check an environment variable to check if we should be sharing Peft model +# weights. When false we treat all Peft models as separate. +peft_share_base_weights = ( + os.environ.get("PEFT_SHARE_BASE_WEIGHTS", "false").lower() == "true" +) + +ANTHROPIC_MODEL_LIST = ( + "claude-1", + "claude-2", + "claude-instant-1", +) + + +class BaseModelAdapter: + """The base and the default model adapter.""" + + use_fast_tokenizer = True + + def match(self, model_path: str): + return True + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + + # BitsAndBytesConfig int-4 config + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16 + ) + + revision = from_pretrained_kwargs.get("revision", "main") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + revision=revision, + trust_remote_code=True, + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision, trust_remote_code=True + ) + try: + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + quantization_config=bnb_config, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + except NameError: + model = AutoModel.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def load_compress_model(self, model_path, device, torch_dtype, revision="main"): + return load_compress_model( + model_path, + device, + torch_dtype, + use_fast=self.use_fast_tokenizer, + revision=revision, + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +# A global registry for all model adapters +# TODO (lmzheng): make it a priority queue. 
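# A sketch of how a new model family could plug into this registry: subclass
# BaseModelAdapter, override `match` and (optionally) the conversation template,
# then register it with `register_model_adapter` defined just below. The
# "acme-chat" name is purely hypothetical; the real adapters later in this file
# (VicunaAdapter, MPTAdapter, ...) follow exactly this shape.
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter, register_model_adapter


class AcmeChatAdapter(BaseModelAdapter):
    """Hypothetical adapter for acme/acme-chat-7b style checkpoints."""

    use_fast_tokenizer = False  # mirror the LLaMA-style adapters in this file

    def match(self, model_path: str):
        return "acme-chat" in model_path.lower()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        # Reuse an existing template here; a real adapter may define its own.
        return get_conv_template("vicuna_v1.1")


# Registration order matters: adapters registered earlier win, and
# BaseModelAdapter is registered last as the catch-all fallback.
register_model_adapter(AcmeChatAdapter)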
+model_adapters: List[BaseModelAdapter] = [] + + +def register_model_adapter(cls): + """Register a model adapter.""" + model_adapters.append(cls()) + + +@cache +def get_model_adapter(model_path: str) -> BaseModelAdapter: + """Get a model adapter for a model_path.""" + model_path_basename = os.path.basename(os.path.normpath(model_path)) + + # Try the basename of model_path at first + for adapter in model_adapters: + if adapter.match(model_path_basename) and type(adapter) != BaseModelAdapter: + return adapter + + # Then try the full path + for adapter in model_adapters: + if adapter.match(model_path): + return adapter + + raise ValueError(f"No valid model adapter for {model_path}") + + +def raise_warning_for_incompatible_cpu_offloading_configuration( + device: str, load_8bit: bool, cpu_offloading: bool +): + if cpu_offloading: + if not load_8bit: + warnings.warn( + "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n" + "Use '--load-8bit' to enable 8-bit-quantization\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if not "linux" in sys.platform: + warnings.warn( + "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if device != "cuda": + warnings.warn( + "CPU-offloading is only enabled when using CUDA-devices\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + return cpu_offloading + + +def load_model( + model_path: str, + device: str = "cuda", + num_gpus: int = 1, + max_gpu_memory: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + debug: bool = False, +): + """Load a model from Hugging Face.""" + # get model adapter + adapter = get_model_adapter(model_path) + + # Handle device mapping + cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( + device, load_8bit, cpu_offloading + ) + if device == "cpu": + kwargs = {"torch_dtype": torch.float32} + if CPU_ISA in ["avx512_bf16", "amx"]: + try: + import intel_extension_for_pytorch as ipex + + kwargs = {"torch_dtype": torch.bfloat16} + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" + ) + elif device == "cuda": + kwargs = {"torch_dtype": torch.float16} + if num_gpus != 1: + kwargs["device_map"] = "auto" + if max_gpu_memory is None: + kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = get_gpu_memory(num_gpus) + kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} + elif device == "mps": + kwargs = {"torch_dtype": torch.float16} + # Avoid bugs in mps backend by not using in-place operations. + replace_llama_attn_with_non_inplace_operations() + elif device == "xpu": + kwargs = {"torch_dtype": torch.bfloat16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import intel_extension_for_pytorch as ipex + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, but is required for xpu inference." 
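# How get_model_adapter above resolves a path in practice: it first matches the
# basename against every adapter except the catch-all BaseModelAdapter, then the
# full path, and only then falls back to BaseModelAdapter (registered last in this
# file). Results are memoized by @cache, so repeated lookups are free. The second
# path below is a hypothetical local checkpoint.
from fastchat.model.model_adapter import get_model_adapter

print(type(get_model_adapter("lmsys/vicuna-7b-v1.5")).__name__)      # VicunaAdapter
print(type(get_model_adapter("/ckpt/some-unknown-model")).__name__)  # BaseModelAdapter (fallback)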
+ ) + elif device == "npu": + kwargs = {"torch_dtype": torch.float16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import torch_npu + except ImportError: + warnings.warn("Ascend Extension for PyTorch is not installed.") + else: + raise ValueError(f"Invalid device: {device}") + + if cpu_offloading: + # raises an error on incompatible platforms + from transformers import BitsAndBytesConfig + + if "max_memory" in kwargs: + kwargs["max_memory"]["cpu"] = ( + str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" + ) + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit_fp32_cpu_offload=cpu_offloading + ) + kwargs["load_in_8bit"] = load_8bit + elif load_8bit: + if num_gpus != 1: + warnings.warn( + "8-bit quantization is not supported for multi-gpu inference." + ) + else: + model, tokenizer = adapter.load_compress_model( + model_path=model_path, + device=device, + torch_dtype=kwargs["torch_dtype"], + revision=revision, + ) + if debug: + print(model) + return model, tokenizer + elif awq_config and awq_config.wbits < 16: + assert ( + awq_config.wbits == 4 + ), "Currently we only support 4-bit inference for AWQ." + model, tokenizer = load_awq_quantized(model_path, awq_config, device) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=[ + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif gptq_config and gptq_config.wbits < 16: + model, tokenizer = load_gptq_quantized(model_path, gptq_config) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=["LlamaDecoderLayer"], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif exllama_config: + model, tokenizer = load_exllama_model(model_path, exllama_config) + return model, tokenizer + elif xft_config: + model, tokenizer = load_xft_model(model_path, xft_config) + return model, tokenizer + kwargs["revision"] = revision + + if dtype is not None: # Overwrite dtype if it is provided in the arguments. 
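# The AWQ/GPTQ multi-GPU branches above lean on `accelerate` to shard an already
# loaded model across devices. A minimal sketch of that pattern, assuming a
# quantized LLaMA-style model is already in memory; the per-GPU budget below is
# illustrative, not a recommendation.
import accelerate


def shard_across_gpus(model, num_gpus: int, per_gpu_budget: str = "13GiB"):
    """Split `model` layer-wise across `num_gpus`, keeping decoder layers intact."""
    device_map = accelerate.infer_auto_device_map(
        model,
        max_memory={i: per_gpu_budget for i in range(num_gpus)},
        no_split_module_classes=["LlamaDecoderLayer"],
    )
    # dispatch_model moves each sub-module to its assigned device and installs
    # hooks so activations are transferred between devices during forward passes.
    return accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)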
+ kwargs["torch_dtype"] = dtype + + # Load model + model, tokenizer = adapter.load_model(model_path, kwargs) + + if ( + device == "cpu" + and kwargs["torch_dtype"] is torch.bfloat16 + and CPU_ISA is not None + ): + model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) + + if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in ( + "mps", + "xpu", + "npu", + ): + pass + #model.to(device) + + if device == "xpu": + model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) + + if debug: + print(model) + + return model, tokenizer + + +def get_conversation_template(model_path: str) -> Conversation: + """Get the default conversation template.""" + adapter = get_model_adapter(model_path) + return adapter.get_default_conv_template(model_path) + + +def get_generate_stream_function(model: torch.nn.Module, model_path: str): + """Get the generate_stream function for inference.""" + from fastchat.serve.inference import generate_stream + + model_type = str(type(model)).lower() + is_chatglm = "chatglm" in model_type + is_falcon = "rwforcausallm" in model_type + is_codet5p = "codet5p" in model_type + is_peft = "peft" in model_type + is_exllama = "exllama" in model_type + is_xft = "xft" in model_type + + if is_chatglm: + return generate_stream_chatglm + elif is_falcon: + return generate_stream_falcon + elif is_codet5p: + return generate_stream_codet5p + elif is_exllama: + return generate_stream_exllama + elif is_xft: + return generate_stream_xft + + elif peft_share_base_weights and is_peft: + # Return a curried stream function that loads the right adapter + # according to the model_name available in this context. This ensures + # the right weights are available. + @torch.inference_mode() + def generate_stream_peft( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, + ): + model.set_adapter(model_path) + for x in generate_stream( + model, + tokenizer, + params, + device, + context_len, + stream_interval, + judge_sent_end, + ): + yield x + + return generate_stream_peft + else: + return generate_stream + + +def add_model_args(parser): + parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", + ) + parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", + ) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="A single GPU like 1 or multiple GPUs like 0,2", + ) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", + default=None, + ) + parser.add_argument( + "--load-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", + ) + parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Used for GPTQ. 
The path to the local GPTQ checkpoint.", + ) + parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="Used for GPTQ. #bits to use for quantization", + ) + parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Used for GPTQ. Groupsize to use for quantization; default uses full row.", + ) + parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Used for GPTQ. Whether to apply the activation order GPTQ heuristic", + ) + parser.add_argument( + "--awq-ckpt", + type=str, + default=None, + help="Used for AWQ. Load quantized model. The path to the local AWQ checkpoint.", + ) + parser.add_argument( + "--awq-wbits", + type=int, + default=16, + choices=[4, 16], + help="Used for AWQ. #bits to use for AWQ quantization", + ) + parser.add_argument( + "--awq-groupsize", + type=int, + default=-1, + help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.", + ) + parser.add_argument( + "--enable-exllama", + action="store_true", + help="Used for exllamabv2. Enable exllamaV2 inference framework.", + ) + parser.add_argument( + "--exllama-max-seq-len", + type=int, + default=4096, + help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.", + ) + parser.add_argument( + "--exllama-gpu-split", + type=str, + default=None, + help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7", + ) + parser.add_argument( + "--enable-xft", + action="store_true", + help="Used for xFasterTransformer Enable xFasterTransformer inference framework.", + ) + parser.add_argument( + "--xft-max-seq-len", + type=int, + default=4096, + help="Used for xFasterTransformer. Max sequence length to use for xFasterTransformer framework; default 4096 sequence length.", + ) + parser.add_argument( + "--xft-dtype", + type=str, + choices=["fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"], + help="Override the default dtype. If not set, it will use bfloat16 for first token and float16 next tokens on CPU.", + default=None, + ) + + +def remove_parent_directory_name(model_path): + """Remove parent directory name.""" + if model_path[-1] == "/": + model_path = model_path[:-1] + return model_path.split("/")[-1] + + +peft_model_cache = {} + + +class PeftModelAdapter: + """Loads any "peft" model and it's base model.""" + + def match(self, model_path: str): + """Accepts any model path with "peft" in the name""" + if os.path.exists(os.path.join(model_path, "adapter_config.json")): + return True + return "peft" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + """Loads the base model then the (peft) adapter weights""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + base_model_path = config.base_model_name_or_path + if "peft" in base_model_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + + # Basic proof of concept for loading peft adapters that share the base + # weights. This is pretty messy because Peft re-writes the underlying + # base model and internally stores a map of adapter layers. + # So, to make this work we: + # 1. Cache the first peft model loaded for a given base models. + # 2. Call `load_model` for any follow on Peft models. + # 3. Make sure we load the adapters by the model_path. Why? This is + # what's accessible during inference time. + # 4. 
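# Sketch of how the CLI flags installed by add_model_args are typically turned back
# into load_model arguments: parse, map the --dtype string onto a torch dtype, and
# forward the rest. The argv list below is only an illustration.
import argparse

import torch
from fastchat.model.model_adapter import add_model_args, load_model

str_to_torch_dtype = {
    None: None,
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

parser = argparse.ArgumentParser()
add_model_args(parser)
args = parser.parse_args(["--model-path", "lmsys/vicuna-7b-v1.5", "--load-8bit"])

model, tokenizer = load_model(
    args.model_path,
    device=args.device,
    num_gpus=args.num_gpus,
    max_gpu_memory=args.max_gpu_memory,
    dtype=str_to_torch_dtype[args.dtype],
    load_8bit=args.load_8bit,
    cpu_offloading=args.cpu_offloading,
    revision=args.revision,
)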
In get_generate_stream_function, make sure we load the right + # adapter before doing inference. This *should* be safe when calls + # are blocked the same semaphore. + if peft_share_base_weights: + if base_model_path in peft_model_cache: + model, tokenizer = peft_model_cache[base_model_path] + # Super important: make sure we use model_path as the + # `adapter_name`. + model.load_adapter(model_path, adapter_name=model_path) + else: + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + # Super important: make sure we use model_path as the + # `adapter_name`. + model = PeftModel.from_pretrained( + base_model, model_path, adapter_name=model_path + ) + peft_model_cache[base_model_path] = (model, tokenizer) + return model, tokenizer + + # In the normal case, load up the base model weights again. + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + model = PeftModel.from_pretrained(base_model, model_path) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + """Uses the conv template of the base model""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + if "peft" in config.base_model_name_or_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + base_model_path = config.base_model_name_or_path + base_adapter = get_model_adapter(base_model_path) + return base_adapter.get_default_conv_template(config.base_model_name_or_path) + + +class VicunaAdapter(BaseModelAdapter): + "Model adapter for Vicuna models (e.g., lmsys/vicuna-7b-v1.5)" "" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + self.raise_warning_for_old_weights(model) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "v0" in remove_parent_directory_name(model_path): + return get_conv_template("one_shot") + return get_conv_template("vicuna_v1.1") + + def raise_warning_for_old_weights(self, model): + if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000: + warnings.warn( + "\nYou are probably using the old Vicuna-v0 model, " + "which will generate unexpected results with the " + "current fastchat.\nYou can try one of the following methods:\n" + "1. Upgrade your weights to the new Vicuna-v1.3: https://github.com/lm-sys/FastChat#vicuna-weights.\n" + "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template one_shot`\n" + "3. Downgrade fschat to fschat==0.1.10 (Not recommended).\n" + ) + + +class AiroborosAdapter(BaseModelAdapter): + """The model adapter for jondurbin/airoboros-*""" + + def match(self, model_path: str): + if re.search(r"airoboros|spicyboros", model_path, re.I): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "-3." 
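# The caching trick used by PeftModelAdapter above, reduced to its essentials:
# load the base weights once, attach every PEFT/LoRA adapter under its own name
# (the model_path), and switch adapters per request with set_adapter. The two
# adapter paths are hypothetical placeholders.
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM

adapter_a = "my-org/vicuna-7b-lora-support"   # hypothetical PEFT checkpoints
adapter_b = "my-org/vicuna-7b-lora-sales"

base_path = PeftConfig.from_pretrained(adapter_a).base_model_name_or_path
base = AutoModelForCausalLM.from_pretrained(base_path, low_cpu_mem_usage=True)

# The first adapter creates the PeftModel; later ones attach to the same base.
model = PeftModel.from_pretrained(base, adapter_a, adapter_name=adapter_a)
model.load_adapter(adapter_b, adapter_name=adapter_b)

model.set_adapter(adapter_a)  # route the next forward pass through adapter A
model.set_adapter(adapter_b)  # ...or through adapter B, without reloading the base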
in model_path or "-3p" in model_path: + return get_conv_template("airoboros_v3") + if "spicyboros" in model_path or re.search(r"-(2\.[2-9]+)", model_path): + return get_conv_template("airoboros_v2") + return get_conv_template("airoboros_v1") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + if "mpt" not in model_path.lower(): + return super().load_model(model_path, from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=True + ) + return model, tokenizer + + +class LongChatAdapter(BaseModelAdapter): + "Model adapter for LongChat models (e.g., lmsys/longchat-7b-16k)." + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "longchat" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + + # Apply monkey patch, TODO(Dacheng): Add flash attention support + config = AutoConfig.from_pretrained(model_path, revision=revision) + replace_llama_with_condense(config.rope_scaling["factor"]) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class GoogleT5Adapter(BaseModelAdapter): + """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["flan-", "fastchat-t5", "codet5p"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + +class KoalaAdapter(BaseModelAdapter): + """The model adapter for Koala""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "koala" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("koala_v1") + + +class AlpacaAdapter(BaseModelAdapter): + """The model adapter for Alpaca""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "alpaca" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class ChatGLMAdapter(BaseModelAdapter): + """The model adapter for THUDM/chatglm-6b, THUDM/chatglm2-6b""" + + def match(self, model_path: str): + return "chatglm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + if "chatglm3" in model_path.lower(): + tokenizer = AutoTokenizer.from_pretrained( + model_path, + encode_special_tokens=True, + trust_remote_code=True, + revision=revision, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, 
revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "chatglm2" in model_path.lower(): + return get_conv_template("chatglm2") + if "chatglm3" in model_path.lower(): + return get_conv_template("chatglm3") + return get_conv_template("chatglm") + + +class CodeGeexAdapter(BaseModelAdapter): + """The model adapter for THUDM/codegeex-6b, THUDM/codegeex2-6b""" + + def match(self, model_path: str): + return "codegeex" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("codegeex") + + +class DollyV2Adapter(BaseModelAdapter): + """The model adapter for databricks/dolly-v2-12b""" + + def match(self, model_path: str): + return "dolly-v2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + # 50277 means "### End" + tokenizer.eos_token_id = 50277 + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolly_v2") + + +class OasstPythiaAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "oasst" in model_path and "pythia" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_pythia") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class OasstLLaMAAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-7-llama-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + if "openassistant-sft-7-llama-30b-hf" in model_path: + return True + return "oasst" in model_path and "pythia" not in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_llama") + + +class OpenChat35Adapter(BaseModelAdapter): + """The model adapter for OpenChat 3.5 (e.g. 
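# Several adapters in this file (DollyV2Adapter just above, and later Pythia, MPT,
# Mistral, Llama-2, CodeLlama, ...) end with the same two lines: mirror the
# tokenizer's eos/pad ids into model.config so generation stops and pads on the
# tokens the tokenizer actually produces. A tiny helper capturing that pattern:
def sync_special_token_ids(model, tokenizer):
    """Copy eos/pad token ids from the tokenizer onto the model config."""
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    return model

# Dolly-v2 additionally repoints eos to token id 50277 ("### End") before the sync,
# which is what makes generation stop at that marker.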
openchat/openchat_3.5)""" + + def match(self, model_path: str): + return "openchat" in model_path.lower() and "3.5" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openchat_3.5") + + +class PythiaAdapter(BaseModelAdapter): + """The model adapter for any EleutherAI/pythia model""" + + def match(self, model_path: str): + return "pythia" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class StableLMAdapter(BaseModelAdapter): + """The model adapter for StabilityAI/stablelm-tuned-alpha-7b""" + + def match(self, model_path: str): + return "stablelm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stablelm") + + +class MPTAdapter(BaseModelAdapter): + """The model adapter for MPT series (mosaicml/mpt-7b-chat, mosaicml/mpt-30b-chat)""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "mpt" in model_path and not "airoboros" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "mpt-7b-chat" in model_path: + return get_conv_template("mpt-7b-chat") + elif "mpt-30b-chat" in model_path: + return get_conv_template("mpt-30b-chat") + elif "mpt-30b-instruct" in model_path: + return get_conv_template("mpt-30b-instruct") + else: + print( + "Warning: Loading base MPT model with `zero_shot` conversation configuration. " + "If this is not desired, inspect model configurations and names." 
+ ) + return get_conv_template("zero_shot") + + +class BaizeAdapter(BaseModelAdapter): + """The model adapter for project-baize/baize-v2-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "baize" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("baize") + + +class RwkvAdapter(BaseModelAdapter): + """The model adapter for BlinkDL/RWKV-4-Raven""" + + def match(self, model_path: str): + return "rwkv-4" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from fastchat.model.rwkv_model import RwkvModel + + model = RwkvModel(model_path) + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("rwkv") + + +class OpenBuddyAdapter(BaseModelAdapter): + """The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "openbuddy" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openbuddy") + + +class PhoenixAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b""" + + def match(self, model_path: str): + return "phoenix" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phoenix") + + +class ReaLMAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/ReaLM-7b""" + + def match(self, model_path: str): + return "ReaLM" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("ReaLM-7b-v1") + + +class ChatGPTAdapter(BaseModelAdapter): + """The model adapter for ChatGPT""" + + def match(self, model_path: str): + return model_path in ( + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "gpt-4", + "gpt-4-turbo", + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class AzureOpenAIAdapter(BaseModelAdapter): + """The model adapter for Azure OpenAI""" + + def match(self, model_path: str): + return model_path in ("azure-gpt-35-turbo", "azure-gpt-4") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class ClaudeAdapter(BaseModelAdapter): + """The model adapter for Claude""" + + def match(self, model_path: str): + return model_path in ANTHROPIC_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("claude") + + +class BardAdapter(BaseModelAdapter): + """The model adapter for Bard""" + + def match(self, model_path: str): + return model_path == 
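# The ChatGPT / Azure / Claude adapters above never load weights; they only supply
# a conversation template that callers turn into an API request. A sketch of that
# flow, assuming the usual Conversation helpers from fastchat.conversation
# (roles, append_message, get_prompt, to_openai_api_messages):
from fastchat.model.model_adapter import get_conversation_template

conv = get_conversation_template("gpt-3.5-turbo")  # resolved via ChatGPTAdapter
conv.append_message(conv.roles[0], "Summarize rotary embeddings in one sentence.")
conv.append_message(conv.roles[1], None)           # placeholder for the reply

# OpenAI-compatible endpoints can take the template directly as the chat `messages`
# payload; Anthropic-style callers would render it with conv.get_prompt() instead.
messages = conv.to_openai_api_messages()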
"bard" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class PaLM2Adapter(BaseModelAdapter): + """The model adapter for PaLM2""" + + def match(self, model_path: str): + return model_path == "palm-2" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class BiLLaAdapter(BaseModelAdapter): + """The model adapter for Neutralzz/BiLLa-7B-SFT""" + + def match(self, model_path: str): + return "billa" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("billa") + + +class RedPajamaINCITEAdapter(BaseModelAdapter): + """The model adapter for togethercomputer/RedPajama-INCITE-7B-Chat""" + + def match(self, model_path: str): + return "redpajama-incite" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("redpajama-incite") + + +class H2OGPTAdapter(BaseModelAdapter): + """The model adapter for h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "h2ogpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("h2ogpt") + + +class RobinAdapter(BaseModelAdapter): + """The model adapter for LMFlow/Full-Robin-7b-v2""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "robin" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Robin") + + +class SnoozyAdapter(BaseModelAdapter): + """The model adapter for nomic-ai/gpt4all-13b-snoozy""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + return "gpt4all" in model_path and "snoozy" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("snoozy") + + +class WizardLMAdapter(BaseModelAdapter): + """The model adapter for WizardLM/WizardLM-13B-V1.0""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardlm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "13b" in model_path or "30b" in model_path or "70b" in model_path: + return get_conv_template("vicuna_v1.1") + else: + # TODO: use the recommended template for 7B + # (https://huggingface.co/WizardLM/WizardLM-13B-V1.0) + return get_conv_template("one_shot") + + +class ManticoreAdapter(BaseModelAdapter): + """The model adapter for openaccess-ai-collective/manticore-13b-chat-pyg""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "manticore" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("manticore") + + +class 
GuanacoAdapter(BaseModelAdapter): + """The model adapter for timdettmers/guanaco-33b-merged""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "guanaco" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + # Fix a bug in tokenizer config + tokenizer.eos_token_id = model.config.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zero_shot") + + +class ChangGPTAdapter(BaseModelAdapter): + """The model adapter for lcw99/polyglot-ko-12.8b-chang-instruct-chat""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "polyglot" in model_path and "chang" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class CamelAdapter(BaseModelAdapter): + """The model adapter for camel-ai/CAMEL-13B-Combined-Data""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "camel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class TuluAdapter(BaseModelAdapter): + """The model adapter for allenai/tulu-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "tulu" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tulu") + + +class FalconAdapter(BaseModelAdapter): + """The model adapter for tiiuae/falcon-40b""" + + def match(self, model_path: str): + return "falcon" in model_path.lower() and "chat" not in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + # Strongly suggest using bf16, which is recommended by the author of Falcon + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + # In Falcon tokenizer config and special config there is not any pad token + # Setting `pad_token_id` to 9, which corresponds to special token '>>SUFFIX<<' + tokenizer.pad_token_id = 9 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon") + + +class FalconChatAdapter(BaseModelAdapter): + def match(self, model_path: str): + return "falcon" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon-chat") + + +class TigerBotAdapter(BaseModelAdapter): + """The model adapter for TigerResearch/tigerbot-7b-sft""" + + def match(self, model_path: str): + return "tigerbot" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + 
trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tigerbot") + + +class BaichuanAdapter(BaseModelAdapter): + """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-7B)""" + + def match(self, model_path: str): + return "baichuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + # for Baichuan-13B-Chat + if "chat" in model_path.lower(): + if "baichuan2" in model_path.lower(): + return get_conv_template("baichuan2-chat") + return get_conv_template("baichuan-chat") + return get_conv_template("zero_shot") + + +class XGenAdapter(BaseModelAdapter): + """The model adapter for Salesforce/xgen-7b""" + + def match(self, model_path: str): + return "xgen" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = 50256 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xgen") + + +class NousHermesAdapter(BaseModelAdapter): + """The model adapter for NousResearch/Nous-Hermes-13b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "nous-hermes" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class InternLMChatAdapter(BaseModelAdapter): + """The model adapter for internlm/internlm-chat-7b""" + + def match(self, model_path: str): + return "internlm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + if "8k" in model_path.lower(): + model.config.max_sequence_length = 8192 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("internlm-chat") + + +class StarChatAdapter(BaseModelAdapter): + """The model adapter for HuggingFaceH4/starchat-beta""" + + def match(self, model_path: str): + return "starchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("starchat") + + +class MistralAdapter(BaseModelAdapter): + """The model adapter for Mistral AI models""" + + def match(self, model_path: str): + return "mistral" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + 
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("mistral") + + +class Llama2Adapter(BaseModelAdapter): + """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)""" + + def match(self, model_path: str): + return "llama-2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + # return get_conv_template("llama-2") + return get_conv_template("one_shot") + +class MeditronAdapter(BaseModelAdapter): + """The model adapter for Meditron""" + + def match(self, model_path: str): + return "meditron" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class CuteGPTAdapter(BaseModelAdapter): + """The model adapter for CuteGPT""" + + def match(self, model_path: str): + return "cutegpt" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = LlamaTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("") + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("cutegpt") + + +class OpenOrcaAdapter(BaseModelAdapter): + """Model adapter for Open-Orca models which may use different prompt templates + - (e.g. Open-Orca/OpenOrcaxOpenChat-Preview2-13B, Open-Orca/Mistral-7B-OpenOrca) + - `OpenOrcaxOpenChat-Preview2-13B` uses their "OpenChat Llama2 V1" prompt template. + - [Open-Orca/OpenOrcaxOpenChat-Preview2-13B #Prompt Template](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B#prompt-template) + - `Mistral-7B-OpenOrca` uses the [OpenAI's Chat Markup Language (ChatML)](https://github.com/openai/openai-python/blob/main/chatml.md) + format, with <|im_start|> and <|im_end|> tokens added to support this. 
+ - [Open-Orca/Mistral-7B-OpenOrca #Prompt Template](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template) + """ + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "mistral-7b-openorca" in model_path.lower() + or "openorca" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "mistral-7b-openorca" in model_path.lower(): + return get_conv_template("mistral-7b-openorca") + return get_conv_template("open-orca") + + +class WizardCoderAdapter(BaseModelAdapter): + """The model adapter for WizardCoder (e.g., WizardLM/WizardCoder-Python-34B-V1.0)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardcoder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + # Same as Alpaca, see : + # https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/inference_wizardcoder.py#L60 + return get_conv_template("alpaca") + + +class QwenChatAdapter(BaseModelAdapter): + """The model adapter for Qwen/Qwen-7B-Chat + To run this model, you need to ensure additional flash attention installation: + ``` bash + git clone https://github.com/Dao-AILab/flash-attention + cd flash-attention && pip install . + pip install csrc/layer_norm + pip install csrc/rotary + ``` + + Since from 2.0, the following change happened + - `flash_attn_unpadded_func` -> `flash_attn_varlen_func` + - `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` + - `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + You may need to revise the code in: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L69 + to from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + """ + + def match(self, model_path: str): + return "qwen" in model_path.lower() + + def float_set(self, config, option): + config.bf16 = False + config.fp16 = False + config.fp32 = False + + if option == "bf16": + config.bf16 = True + elif option == "fp16": + config.fp16 = True + elif option == "fp32": + config.fp32 = True + else: + print("Invalid option. 
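# The OpenOrcaAdapter docstring above notes that Mistral-7B-OpenOrca expects the
# ChatML layout with <|im_start|>/<|im_end|> markers. A hand-written sketch of that
# general shape so the wire format is visible; the exact whitespace produced by the
# "mistral-7b-openorca" template may differ, and the message text is illustrative.
def chatml_prompt(system: str, user: str) -> str:
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n"
        f"<|im_start|>user\n{user}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

print(chatml_prompt("You are a helpful assistant.", "What is a delta weight?"))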
Please choose one from 'bf16', 'fp16' and 'fp32'.") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from transformers.generation import GenerationConfig + + revision = from_pretrained_kwargs.get("revision", "main") + config = AutoConfig.from_pretrained( + model_path, + trust_remote_code=True, + ) + # NOTE: if you use the old version of model file, please remove the comments below + # config.use_flash_attn = False + self.float_set(config, "fp16") + generation_config = GenerationConfig.from_pretrained( + model_path, trust_remote_code=True + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + config=config, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ).eval() + if hasattr(model.config, "use_dynamic_ntk") and model.config.use_dynamic_ntk: + model.config.max_sequence_length = 16384 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + tokenizer.eos_token_id = config.eos_token_id + tokenizer.bos_token_id = config.bos_token_id + tokenizer.pad_token_id = generation_config.pad_token_id + model.config.eos_token_id = tokenizer.eos_token_id + model.config.bos_token_id = tokenizer.bos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("qwen-7b-chat") + + +class BGEAdapter(BaseModelAdapter): + """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "bge" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class E5Adapter(BaseModelAdapter): + """The model adapter for E5 (e.g., intfloat/e5-large-v2)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "e5-" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class AquilaChatAdapter(BaseModelAdapter): + """The model adapter for BAAI/Aquila + + Now supports: + - BAAI/AquilaChat-7B + - BAAI/AquilaChat2-7B + - BAAI/AquilaChat2-34B + """ + + def match(self, model_path: str): + return "aquila" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = 
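# BGEAdapter / E5Adapter above load plain encoder checkpoints via AutoModel rather
# than causal LMs; they are embedding models. A minimal sketch of producing sentence
# embeddings with a BGE-style encoder. CLS pooling plus L2 normalisation is the
# commonly documented recipe for these checkpoints; treat it as an assumption rather
# than something the adapter itself enforces.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

name = "BAAI/bge-large-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name).eval()

batch = tokenizer(["what is a model adapter?"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(**batch)
embeddings = F.normalize(out.last_hidden_state[:, 0], p=2, dim=1)  # CLS pooling
print(embeddings.shape)  # (1, hidden_size)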
from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + # See: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347 + if "aquilachat2" in model_path: + if "16k" in model_path: + return get_conv_template("aquila") + elif "34b" in model_path: + return get_conv_template("aquila-legacy") + else: + return get_conv_template("aquila-v1") + else: + return get_conv_template("aquila-chat") + + +class Lamma2ChineseAdapter(BaseModelAdapter): + """The model adapter for FlagAlpha/LLama2-Chinese sft""" + + def match(self, model_path: str): + return "llama2-chinese" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama2-chinese") + + +class VigogneAdapter(BaseModelAdapter): + """The model adapter for vigogne (e.g., bofenghuang/vigogne-2-7b-chat)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return bool(re.search(r"vigogne|vigostral", model_path, re.I)) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "chat" in model_path.lower(): + if "vigostral" in model_path.lower(): + return get_conv_template("vigogne_chat_v3") + return get_conv_template("vigogne_chat_v2") + return get_conv_template("vigogne_instruct") + + +class OpenLLaMaOpenInstructAdapter(BaseModelAdapter): + """The model adapter for OpenLLaMa-Open-Instruct (e.g., VMware/open-llama-7b-open-instruct)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "open-llama" in model_path.lower() and "open-instruct" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class CodeLlamaAdapter(BaseModelAdapter): + """The model adapter for CodeLlama (e.g., 
codellama/CodeLlama-34b-hf)""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class PhindCodeLlamaAdapter(CodeLlamaAdapter): + """The model adapter for Phind-CodeLlama (e.g., Phind/Phind-CodeLlama-34B-v2)""" + + def match(self, model_path: str): + return "phind-codellama-" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phind") + + +class Llama2ChangAdapter(Llama2Adapter): + """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)""" + + def match(self, model_path: str): + return "llama2-ko-chang" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class ZephyrAdapter(BaseModelAdapter): + """The model adapter for Zephyr (e.g. HuggingFaceH4/zephyr-7b-alpha)""" + + def match(self, model_path: str): + return "zephyr" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class XwinLMAdapter(BaseModelAdapter): + """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return "xwin-lm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class LemurAdapter(BaseModelAdapter): + """The model adapter for OpenLemur/lemur-70b-chat-v1""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "lemur-70b-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("lemur-70b-chat") + + +class PygmalionAdapter(BaseModelAdapter): + """The model adapter for Pygmalion/Metharme series of models(e.g., PygmalionAI/mythalion-13b)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return bool( + re.search(r"pygmalion|mythalion|metharme", model_path.lower(), re.I) + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metharme") + + +# Note: the registration order matters. +# The one registered earlier has a higher matching priority. 
+register_model_adapter(PeftModelAdapter) +register_model_adapter(VicunaAdapter) +register_model_adapter(AiroborosAdapter) +register_model_adapter(LongChatAdapter) +register_model_adapter(GoogleT5Adapter) +register_model_adapter(KoalaAdapter) +register_model_adapter(AlpacaAdapter) +register_model_adapter(ChatGLMAdapter) +register_model_adapter(CodeGeexAdapter) +register_model_adapter(DollyV2Adapter) +register_model_adapter(OasstPythiaAdapter) +register_model_adapter(OasstLLaMAAdapter) +register_model_adapter(OpenChat35Adapter) +register_model_adapter(StableLMAdapter) +register_model_adapter(BaizeAdapter) +register_model_adapter(RwkvAdapter) +register_model_adapter(OpenBuddyAdapter) +register_model_adapter(PhoenixAdapter) +register_model_adapter(BardAdapter) +register_model_adapter(PaLM2Adapter) +register_model_adapter(ChatGPTAdapter) +register_model_adapter(AzureOpenAIAdapter) +register_model_adapter(ClaudeAdapter) +register_model_adapter(MPTAdapter) +register_model_adapter(BiLLaAdapter) +register_model_adapter(RedPajamaINCITEAdapter) +register_model_adapter(H2OGPTAdapter) +register_model_adapter(RobinAdapter) +register_model_adapter(SnoozyAdapter) +register_model_adapter(WizardLMAdapter) +register_model_adapter(ManticoreAdapter) +register_model_adapter(GuanacoAdapter) +register_model_adapter(CamelAdapter) +register_model_adapter(ChangGPTAdapter) +register_model_adapter(TuluAdapter) +register_model_adapter(FalconChatAdapter) +register_model_adapter(FalconAdapter) +register_model_adapter(TigerBotAdapter) +register_model_adapter(BaichuanAdapter) +register_model_adapter(XGenAdapter) +register_model_adapter(NousHermesAdapter) +register_model_adapter(PythiaAdapter) +register_model_adapter(InternLMChatAdapter) +register_model_adapter(StarChatAdapter) +register_model_adapter(Llama2Adapter) +register_model_adapter(CuteGPTAdapter) +register_model_adapter(OpenOrcaAdapter) +register_model_adapter(MistralAdapter) +register_model_adapter(WizardCoderAdapter) +register_model_adapter(QwenChatAdapter) +register_model_adapter(AquilaChatAdapter) +register_model_adapter(BGEAdapter) +register_model_adapter(E5Adapter) +register_model_adapter(Lamma2ChineseAdapter) +register_model_adapter(VigogneAdapter) +register_model_adapter(OpenLLaMaOpenInstructAdapter) +register_model_adapter(ReaLMAdapter) +register_model_adapter(PhindCodeLlamaAdapter) +register_model_adapter(CodeLlamaAdapter) +register_model_adapter(Llama2ChangAdapter) +register_model_adapter(ZephyrAdapter) +register_model_adapter(XwinLMAdapter) +register_model_adapter(LemurAdapter) +register_model_adapter(PygmalionAdapter) +register_model_adapter(MeditronAdapter) + + +# After all adapters, try the default base adapter. +register_model_adapter(BaseModelAdapter) diff --git a/backend/FastChat/fastchat/model/model_chatglm.py b/backend/FastChat/fastchat/model/model_chatglm.py new file mode 100644 index 0000000..5d4db62 --- /dev/null +++ b/backend/FastChat/fastchat/model/model_chatglm.py @@ -0,0 +1,102 @@ +""" +Inference code for ChatGLM. +Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. 
+""" +import re + +import torch +from transformers.generation.logits_process import LogitsProcessor + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +invalid_score_processor = InvalidScoreLogitsProcessor() + + +def process_response(response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + +@torch.inference_mode() +def generate_stream_chatglm( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_new_tokens", 256)) + echo = params.get("echo", True) + + inputs = tokenizer([prompt], return_tensors="pt").to(model.device) + input_echo_len = len(inputs["input_ids"][0]) + + gen_kwargs = { + "max_length": max_new_tokens + input_echo_len, + "do_sample": True if temperature > 1e-5 else False, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "logits_processor": [invalid_score_processor], + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + total_len = 0 + for total_ids in model.stream_generate(**inputs, **gen_kwargs): + total_ids = total_ids.tolist()[0] + total_len = len(total_ids) + if echo: + output_ids = total_ids + else: + output_ids = total_ids[input_echo_len:] + response = tokenizer.decode(output_ids) + response = process_response(response) + + yield { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": None, + } + + # TODO: ChatGLM stop when it reach max length + # Only last stream result contains finish_reason, we set finish_reason as stop + ret = { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": "stop", + } + yield ret diff --git a/backend/FastChat/fastchat/model/model_codet5p.py b/backend/FastChat/fastchat/model/model_codet5p.py new file mode 100644 index 0000000..0984513 --- /dev/null +++ b/backend/FastChat/fastchat/model/model_codet5p.py @@ -0,0 +1,108 @@ +import gc +from threading import Thread +import torch +import transformers +from transformers import ( + GenerationConfig, + StoppingCriteria, + StoppingCriteriaList, + TextIteratorStreamer, +) + + +@torch.inference_mode() +def generate_stream_codet5p( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 1024)) + stop_token_ids = params.get("stop_token_ids", None) or [] 
+ stop_token_ids.append(tokenizer.eos_token_id) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, **decode_config) + encoding = tokenizer(prompt, return_tensors="pt").to(device) + input_ids = encoding.input_ids + encoding["decoder_input_ids"] = encoding["input_ids"].clone() + input_echo_len = len(input_ids) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + class CodeBlockStopper(StoppingCriteria): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + # Code-completion is open-end generation. + # We check \n\n to stop at end of a code block. + if list(input_ids[0][-2:]) == [628, 198]: + return True + return False + + gen_kwargs = dict( + **encoding, + streamer=streamer, + generation_config=generation_config, + stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), + ) + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + i = 0 + output = "" + for new_text in streamer: + i += 1 + output += new_text + if i % stream_interval == 0 or i == max_new_tokens - 1: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + if i >= max_new_tokens: + break + + if i >= max_new_tokens: + finish_reason = "length" + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + thread.join() + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/backend/FastChat/fastchat/model/model_exllama.py b/backend/FastChat/fastchat/model/model_exllama.py new file mode 100644 index 0000000..306edab --- /dev/null +++ b/backend/FastChat/fastchat/model/model_exllama.py @@ -0,0 +1,77 @@ +import gc +import sys +from typing import Dict + +import torch + + +def generate_stream_exllama( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + try: + from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler + except ImportError as e: + print(f"Error: Failed to load Exllamav2. 
{e}") + sys.exit(-1) + + prompt = params["prompt"] + + generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) + settings = ExLlamaV2Sampler.Settings() + + settings.temperature = float(params.get("temperature", 0.85)) + settings.top_k = int(params.get("top_k", 50)) + settings.top_p = float(params.get("top_p", 0.8)) + settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) + settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) + + max_new_tokens = int(params.get("max_new_tokens", 256)) + + generator.set_stop_conditions(params.get("stop_token_ids", None) or []) + echo = bool(params.get("echo", True)) + + input_ids = generator.tokenizer.encode(prompt) + prompt_tokens = input_ids.shape[-1] + generator.begin_stream(input_ids, settings) + + generated_tokens = 0 + if echo: + output = prompt + else: + output = "" + while True: + chunk, eos, _ = generator.stream() + output += chunk + generated_tokens += 1 + if generated_tokens == max_new_tokens: + finish_reason = "length" + break + elif eos: + finish_reason = "length" + break + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": None, + } + + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/backend/FastChat/fastchat/model/model_falcon.py b/backend/FastChat/fastchat/model/model_falcon.py new file mode 100644 index 0000000..dc8af8e --- /dev/null +++ b/backend/FastChat/fastchat/model/model_falcon.py @@ -0,0 +1,140 @@ +import gc +from threading import Thread +from typing import Iterable + +import torch +import transformers +from transformers import TextIteratorStreamer, GenerationConfig + +from fastchat.utils import is_partial_stop + + +@torch.inference_mode() +def generate_stream_falcon( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + stop_str = params.get("stop", None) + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] # truncate from the left + attention_mask = attention_mask[-max_src_len:] # truncate from the left + input_echo_len = len(input_ids) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + generation_kwargs = dict( + inputs=input_ids, + 
attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/backend/FastChat/fastchat/model/model_registry.py b/backend/FastChat/fastchat/model/model_registry.py new file mode 100644 index 0000000..a8e603c --- /dev/null +++ b/backend/FastChat/fastchat/model/model_registry.py @@ -0,0 +1,406 @@ +"""Additional information of the models.""" +from collections import namedtuple +from typing import List + + +ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) + + +model_info = {} + + +def register_model_info( + full_names: List[str], simple_name: str, link: str, description: str +): + info = ModelInfo(simple_name, link, description) + + for full_name in full_names: + model_info[full_name] = info + + +def get_model_info(name: str) -> ModelInfo: + if name in model_info: + return model_info[name] + else: + # To fix this, please use `register_model_info` to register your model + return ModelInfo( + name, "", "Register the description at fastchat/model/model_registry.py" + ) + + +register_model_info( + ["gpt-3.5-turbo"], + "GPT-3.5", + "https://openai.com/blog/chatgpt", + "GPT-3.5 by OpenAI", +) +register_model_info( + ["gpt-3.5-turbo-1106"], + "GPT-3.5-Turbo-1106", + "https://platform.openai.com/docs/models/gpt-3-5", + "GPT-3.5-Turbo-1106 by OpenAI", +) +register_model_info( + ["gpt-4"], "GPT-4", "https://openai.com/research/gpt-4", "ChatGPT-4 by OpenAI" +) +register_model_info( + ["gpt-4-turbo"], + "GPT-4-Turbo", + "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", + "GPT-4-Turbo by OpenAI", +) +register_model_info( + ["claude-2"], + "Claude", + "https://www.anthropic.com/index/claude-2", + "Claude 2 by Anthropic", +) +register_model_info( + ["claude-1"], + "Claude", + "https://www.anthropic.com/index/introducing-claude", + "Claude by Anthropic", +) +register_model_info( + 
["claude-instant-1"], + "Claude Instant", + "https://www.anthropic.com/index/introducing-claude", + "Claude Instant by Anthropic", +) +register_model_info( + ["palm-2"], + "PaLM 2 Chat", + "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", + "PaLM 2 for Chat (chat-bison@001) by Google", +) +register_model_info( + [ + "vicuna-33b", + "vicuna-33b-v1.3", + "vicuna-13b", + "vicuna-13b-v1.3", + "vicuna-7b", + "vicuna-7b-v1.3", + ], + "Vicuna", + "https://lmsys.org/blog/2023-03-30-vicuna/", + "a chat assistant fine-tuned on user-shared conversations by LMSYS", +) +register_model_info( + ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"], + "Llama 2", + "https://ai.meta.com/llama/", + "open foundation and fine-tuned chat models by Meta", +) +register_model_info( + ["mistral-7b-instruct"], + "Mistral", + "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1", + "a large language model by Mistral AI team", +) +register_model_info( + ["zephyr-7b-beta", "zephyr-7b-alpha"], + "Zephyr", + "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha", + "a chatbot fine-tuned from Mistral by Hugging Face", +) +register_model_info( + ["qwen-14b-chat"], + "Qwen", + "https://huggingface.co/Qwen/Qwen-14B-Chat", + "a large language model by Alibaba Cloud", +) +register_model_info( + ["codellama-34b-instruct", "codellama-13b-instruct", "codellama-7b-instruct"], + "Code Llama", + "https://ai.meta.com/blog/code-llama-large-language-model-coding/", + "open foundation models for code by Meta", +) +register_model_info( + ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM", + "an instruction-following LLM using evol-instruct by Microsoft", +) +register_model_info( + ["wizardcoder-15b-v1.0"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder", + "Empowering Code Large Language Models with Evol-Instruct", +) +register_model_info( + ["mpt-7b-chat", "mpt-30b-chat"], + "MPT-Chat", + "https://www.mosaicml.com/blog/mpt-30b", + "a chatbot fine-tuned from MPT by MosaicML", +) +register_model_info( + ["guanaco-33b", "guanaco-65b"], + "Guanaco", + "https://github.com/artidoro/qlora", + "a model fine-tuned with QLoRA by UW", +) +register_model_info( + ["gpt4all-13b-snoozy"], + "GPT4All-Snoozy", + "https://github.com/nomic-ai/gpt4all", + "a finetuned LLaMA model on assistant style data by Nomic AI", +) +register_model_info( + ["koala-13b"], + "Koala", + "https://bair.berkeley.edu/blog/2023/04/03/koala", + "a dialogue model for academic research by BAIR", +) +register_model_info( + ["RWKV-4-Raven-14B"], + "RWKV-4-Raven", + "https://huggingface.co/BlinkDL/rwkv-4-raven", + "an RNN with transformer-level LLM performance", +) +register_model_info( + ["chatglm-6b", "chatglm2-6b"], + "ChatGLM", + "https://chatglm.cn/blog", + "an open bilingual dialogue language model by Tsinghua University", +) +register_model_info( + ["alpaca-13b"], + "Alpaca", + "https://crfm.stanford.edu/2023/03/13/alpaca.html", + "a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", +) +register_model_info( + ["oasst-pythia-12b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "an Open Assistant for everyone by LAION", +) +register_model_info( + ["oasst-sft-7-llama-30b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "an Open Assistant for everyone by LAION", +) +register_model_info( + ["openchat-3.5"], + "OpenChat 3.5", + "https://github.com/imoneoi/openchat", + "OpenChat 3.5 is a 
versatile, open-source language model fine-tuned using C-RLFT", +) +register_model_info( + ["llama-7b", "llama-13b"], + "LLaMA", + "https://arxiv.org/abs/2302.13971", + "open and efficient foundation language models by Meta", +) +register_model_info( + ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"], + "Open LLaMa (Open Instruct)", + "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc", + "Open LLaMa fine-tuned on instruction-following data by VMware", +) +register_model_info( + ["dolly-v2-12b"], + "Dolly", + "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm", + "an instruction-tuned open large language model by Databricks", +) +register_model_info( + ["stablelm-tuned-alpha-7b"], + "StableLM", + "https://github.com/stability-AI/stableLM", + "Stability AI language models", +) +register_model_info( + ["codet5p-6b"], + "CodeT5p-6b", + "https://huggingface.co/Salesforce/codet5p-6b", + "Code completion model released by Salesforce", +) +register_model_info( + ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"], + "FastChat-T5", + "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0", + "a chat assistant fine-tuned from FLAN-T5 by LMSYS", +) +register_model_info( + ["phoenix-inst-chat-7b"], + "Phoenix-7B", + "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b", + "a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", +) +register_model_info( + ["realm-7b-v1"], + "ReaLM", + "https://github.com/FreedomIntelligence/ReaLM", + "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.", +) +register_model_info( + ["billa-7b-sft"], + "BiLLa-7B-SFT", + "https://huggingface.co/Neutralzz/BiLLa-7B-SFT", + "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", +) +register_model_info( + ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"], + "h2oGPT-GM-7b", + "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", + "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", +) +register_model_info( + ["baize-v2-7b", "baize-v2-13b"], + "Baize v2", + "https://github.com/project-baize/baize-chatbot#v2", + "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.", +) +register_model_info( + [ + "airoboros-l2-7b-2.1", + "airoboros-l2-13b-2.1", + "airoboros-c34b-2.1", + "airoboros-l2-70b-2.1", + ], + "airoboros", + "https://huggingface.co/jondurbin/airoboros-l2-70b-2.1", + "an instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4", +) +register_model_info( + [ + "spicyboros-7b-2.2", + "spicyboros-13b-2.2", + "spicyboros-70b-2.2", + ], + "spicyboros", + "https://huggingface.co/jondurbin/spicyboros-70b-2.2", + "de-aligned versions of the airoboros models", +) +register_model_info( + ["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"], + "Robin-v2", + "https://huggingface.co/OptimalScale/robin-7b-v2-delta", + "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.", +) +register_model_info( + ["manticore-13b-chat"], + "Manticore 13B Chat", + "https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg", + "A chatbot fine-tuned from LlaMa across several 
CoT and chat datasets.", +) +register_model_info( + ["redpajama-incite-7b-chat"], + "RedPajama-INCITE-7B-Chat", + "https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat", + "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together", +) +register_model_info( + [ + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "falcon-180b", + "falcon-180b-chat", + ], + "Falcon", + "https://huggingface.co/tiiuae/falcon-180B", + "TII's flagship series of large language models", +) +register_model_info( + ["tigerbot-7b-sft"], + "Tigerbot", + "https://huggingface.co/TigerResearch/tigerbot-7b-sft", + "TigerBot is a large-scale language model (LLM) with multiple languages and tasks.", +) +register_model_info( + ["internlm-chat-7b", "internlm-chat-7b-8k"], + "InternLM", + "https://huggingface.co/internlm/internlm-chat-7b", + "InternLM is a multi-language large-scale language model (LLM), developed by SHLAB.", +) +register_model_info( + ["Qwen-7B-Chat"], + "Qwen", + "https://huggingface.co/Qwen/Qwen-7B-Chat", + "Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.", +) +register_model_info( + ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"], + "Llama2-Chinese", + "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat", + "Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.", +) +register_model_info( + ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"], + "Chinese-Alpaca", + "https://huggingface.co/hfl/chinese-alpaca-2-13b", + "New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.", +) +register_model_info( + ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"], + "Vigogne-Instruct", + "https://huggingface.co/bofenghuang/vigogne-2-7b-instruct", + "Vigogne-Instruct is a French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang", +) +register_model_info( + ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"], + "Vigogne-Chat", + "https://huggingface.co/bofenghuang/vigogne-2-7b-chat", + "Vigogne-Chat is a French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang", +) +register_model_info( + ["stable-vicuna-13B-HF"], + "stable-vicuna", + "https://huggingface.co/TheBloke/stable-vicuna-13B-HF", + "StableVicuna is a Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.", +) +register_model_info( + ["deluxe-chat-v1", "deluxe-chat-v1.1"], + "DeluxeChat", + "", + "Deluxe Chat", +) +register_model_info( + [ + "Xwin-LM-7B-V0.1", + "Xwin-LM-13B-V0.1", + "Xwin-LM-70B-V0.1", + "Xwin-LM-7B-V0.2", + "Xwin-LM-13B-V0.2", + ], + "Xwin-LM", + "https://github.com/Xwin-LM/Xwin-LM", + "Chat models developed by Xwin-LM team", +) + +register_model_info( + ["lemur-70b-chat"], + "Lemur-Chat", + "https://huggingface.co/OpenLemur/lemur-70b-chat-v1", + "an openly accessible language model optimized for both natural language and coding capabilities ", +) + +register_model_info( + ["Mistral-7B-OpenOrca"], + "Open-Orca", + "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca", + "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)", +) + +register_model_info( + [ + "AquilaChat-7B", + "AquilaChat2-7B", + "AquilaChat2-34B", + ], + "Aquila-Chat", + "https://huggingface.co/BAAI/AquilaChat2-34B", + "Chat models developed by BAAI team", +) + 
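For reference, the entries registered in this module are looked up through `get_model_info`, which falls back to a stub `ModelInfo` for names that were never registered. A small usage sketch, outside this diff; the `my-finetune-7b` name and URL are hypothetical placeholders:

```python
from fastchat.model.model_registry import get_model_info, register_model_info

# A registered name returns its curated metadata.
info = get_model_info("vicuna-13b")
print(info.simple_name, info.link)

# An unregistered name is not an error: it yields a stub whose description
# points back to this registry module.
print(get_model_info("my-finetune-7b").description)  # hypothetical name

# Registering the name once makes later lookups return real metadata.
register_model_info(
    ["my-finetune-7b"],
    "My-Finetune",
    "https://example.com/my-finetune",  # placeholder URL
    "a hypothetical fine-tune used only to illustrate the registry API",
)
```

The registrations continue below.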
+register_model_info( + ["Yi-34B-Chat"], + "Yi-Chat", + "https://huggingface.co/01-ai", + "A large language model by 01.AI.", +) diff --git a/backend/FastChat/fastchat/model/model_xfastertransformer.py b/backend/FastChat/fastchat/model/model_xfastertransformer.py new file mode 100644 index 0000000..54890b1 --- /dev/null +++ b/backend/FastChat/fastchat/model/model_xfastertransformer.py @@ -0,0 +1,81 @@ +import gc +from threading import Thread + +import torch +from transformers import TextIteratorStreamer + + +@torch.inference_mode() +def generate_stream_xft( + model, + tokenizer, + params, + device, + context_len=8192, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + + # unused now, and placehold for future. + # temperature = float(params.get("temperature", 1.0)) + # top_p = float(params.get("top_p", 1.0)) + + max_new_tokens = int(params.get("max_new_tokens", 4096)) + echo = params.get("echo", True) + + inputs = tokenizer( + prompt, return_tensors="pt", padding=model.config.padding + ).input_ids + input_echo_len = len(inputs[0]) + max_len = max_new_tokens + input_echo_len + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + generation_kwargs = { + "input_ids": inputs, + "streamer": streamer, + "max_length": max_len, + "num_beams": model.config.beam_width, + "length_penalty": repetition_penalty, + "num_return_sequences": model.config.num_return_sequences, + "early_stopping": model.config.early_stopping, + "eos_token_id": model.config.eos_token_id, + "pad_token_id": model.config.pad_token_id, + } + + thread = Thread(target=model.model.generate, kwargs=generation_kwargs) + thread.start() + if echo: + # means keep the prompt + output = prompt + else: + output = "" + i = 0 + for i, new_text in enumerate(streamer): + output += new_text + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + if i == max_new_tokens - 1: + finish_reason = "length" + else: + finish_reason = "stop" + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/backend/FastChat/fastchat/model/monkey_patch_non_inplace.py b/backend/FastChat/fastchat/model/monkey_patch_non_inplace.py new file mode 100644 index 0000000..413dd3b --- /dev/null +++ b/backend/FastChat/fastchat/model/monkey_patch_non_inplace.py @@ -0,0 +1,119 @@ +""" +Monkey patch the llama implementation in the huggingface/transformers library. +Avoid bugs in mps backend by not using in-place operations. 
+""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +import transformers + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2].clone() + x2 = x[..., x.shape[-1] // 2 :].clone() + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, 
past_key_value + + +def replace_llama_attn_with_non_inplace_operations(): + """Avoid bugs in mps backend by not using in-place operations.""" + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/backend/FastChat/fastchat/model/rwkv_model.py b/backend/FastChat/fastchat/model/rwkv_model.py new file mode 100644 index 0000000..bdbc145 --- /dev/null +++ b/backend/FastChat/fastchat/model/rwkv_model.py @@ -0,0 +1,76 @@ +import os +from types import SimpleNamespace +import warnings + +import torch + +os.environ["RWKV_JIT_ON"] = "1" +os.environ["RWKV_CUDA_ON"] = "1" + +from rwkv.model import RWKV +from rwkv.utils import PIPELINE, PIPELINE_ARGS + + +class RwkvModel: + def __init__(self, model_path): + warnings.warn( + "Experimental support. Please use ChatRWKV if you want to chat with RWKV" + ) + self.config = SimpleNamespace(is_encoder_decoder=False) + self.model = RWKV(model=model_path, strategy="cuda fp16") + # two GPUs + # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") + + self.tokenizer = None + self.model_path = model_path + + def to(self, target): + assert target == "cuda" + + def __call__(self, input_ids, use_cache, past_key_values=None): + assert use_cache == True + input_ids = input_ids[0].detach().cpu().numpy() + # print(input_ids) + logits, state = self.model.forward(input_ids, past_key_values) + # print(logits) + logits = logits.unsqueeze(0).unsqueeze(0) + out = SimpleNamespace(logits=logits, past_key_values=state) + return out + + def generate( + self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 + ): + # This function is used by fastchat.llm_judge. + # Because RWKV does not support huggingface generation API, + # we reuse fastchat.serve.inference.generate_stream as a workaround. + from transformers import AutoTokenizer + + from fastchat.serve.inference import generate_stream + from fastchat.conversation import get_conv_template + + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", use_fast=True + ) + prompt = self.tokenizer.decode(input_ids[0].tolist()) + conv = get_conv_template("rwkv") + + gen_params = { + "model": self.model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") + + for res in res_iter: + pass + + output = res["text"] + output_ids = self.tokenizer.encode(output) + + return [input_ids[0].tolist() + output_ids] diff --git a/backend/FastChat/fastchat/model/upload_hub.py b/backend/FastChat/fastchat/model/upload_hub.py new file mode 100644 index 0000000..b151965 --- /dev/null +++ b/backend/FastChat/fastchat/model/upload_hub.py @@ -0,0 +1,45 @@ +""" +Upload weights to huggingface. 
+ +Usage: +python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 +""" +import argparse +import tempfile + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def upload_hub(model_path, hub_repo_id, component, private): + if component == "all": + components = ["model", "tokenizer"] + else: + components = [component] + + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} + + if "model" in components: + model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + with tempfile.TemporaryDirectory() as tmp_path: + model.save_pretrained(tmp_path, **kwargs) + + if "tokenizer" in components: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + with tempfile.TemporaryDirectory() as tmp_path: + tokenizer.save_pretrained(tmp_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str, required=True) + parser.add_argument( + "--component", type=str, choices=["all", "model", "tokenizer"], default="all" + ) + parser.add_argument("--private", action="store_true") + args = parser.parse_args() + + upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) diff --git a/backend/FastChat/fastchat/modules/__init__.py b/backend/FastChat/fastchat/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/FastChat/fastchat/modules/awq.py b/backend/FastChat/fastchat/modules/awq.py new file mode 100644 index 0000000..1f27be8 --- /dev/null +++ b/backend/FastChat/fastchat/modules/awq.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass, field +from pathlib import Path +import sys + +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils + + +@dataclass +class AWQConfig: + ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local AWQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + + +def load_awq_quantized(model_name, awq_config: AWQConfig, device): + print("Loading AWQ quantized model...") + + try: + from tinychat.utils import load_quant + from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + except ImportError as e: + print(f"Error: Failed to import tinychat. 
{e}") + print("Please double check if you have successfully installed AWQ") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") + sys.exit(-1) + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=False, trust_remote_code=True + ) + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.kaiming_normal_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + modeling_utils._init_weights = False + + torch.set_default_dtype(torch.half) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): + model = load_quant.load_awq_llama_fast( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + make_quant_attn(model, device) + make_quant_norm(model) + make_fused_mlp(model) + else: + model = load_quant.load_awq_model( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + return model, tokenizer + + +def find_awq_ckpt(awq_config: AWQConfig): + if Path(awq_config.ckpt).is_file(): + return awq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(awq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: AWQ checkpoint not found") + sys.exit(1) diff --git a/backend/FastChat/fastchat/modules/exllama.py b/backend/FastChat/fastchat/modules/exllama.py new file mode 100644 index 0000000..5e5fc81 --- /dev/null +++ b/backend/FastChat/fastchat/modules/exllama.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, field +import sys + + +@dataclass +class ExllamaConfig: + max_seq_len: int + gpu_split: str = None + cache_8bit: bool = False + + +class ExllamaModel: + def __init__(self, exllama_model, exllama_cache): + self.model = exllama_model + self.cache = exllama_cache + self.config = self.model.config + + +def load_exllama_model(model_path, exllama_config: ExllamaConfig): + try: + from exllamav2 import ( + ExLlamaV2Config, + ExLlamaV2Tokenizer, + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ) + except ImportError as e: + print(f"Error: Failed to load Exllamav2. 
{e}") + sys.exit(-1) + + exllamav2_config = ExLlamaV2Config() + exllamav2_config.model_dir = model_path + exllamav2_config.prepare() + exllamav2_config.max_seq_len = exllama_config.max_seq_len + exllamav2_config.cache_8bit = exllama_config.cache_8bit + + exllama_model = ExLlamaV2(exllamav2_config) + tokenizer = ExLlamaV2Tokenizer(exllamav2_config) + + split = None + if exllama_config.gpu_split: + split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] + exllama_model.load(split) + + cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache + exllama_cache = cache_class(exllama_model) + model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) + + return model, tokenizer diff --git a/backend/FastChat/fastchat/modules/gptq.py b/backend/FastChat/fastchat/modules/gptq.py new file mode 100644 index 0000000..fe0a220 --- /dev/null +++ b/backend/FastChat/fastchat/modules/gptq.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass, field +import os +from os.path import isdir, isfile +from pathlib import Path +import sys + +from transformers import AutoTokenizer + + +@dataclass +class GptqConfig: + ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local GPTQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + act_order: bool = field( + default=True, + metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, + ) + + +def load_gptq_quantized(model_name, gptq_config: GptqConfig): + print("Loading GPTQ quantized model...") + + try: + script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") + + sys.path.insert(0, module_path) + from llama import load_quant + except ImportError as e: + print(f"Error: Failed to load GPTQ-for-LLaMa. 
{e}") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") + sys.exit(-1) + + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + # only `fastest-inference-4bit` branch cares about `act_order` + if gptq_config.act_order: + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + act_order=gptq_config.act_order, + ) + else: + # other branches + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + ) + + return model, tokenizer + + +def find_gptq_ckpt(gptq_config: GptqConfig): + if Path(gptq_config.ckpt).is_file(): + return gptq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: gptq checkpoint not found") + sys.exit(1) diff --git a/backend/FastChat/fastchat/modules/xfastertransformer.py b/backend/FastChat/fastchat/modules/xfastertransformer.py new file mode 100644 index 0000000..0b49bea --- /dev/null +++ b/backend/FastChat/fastchat/modules/xfastertransformer.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +import sys + + +@dataclass +class XftConfig: + max_seq_len: int = 4096 + beam_width: int = 1 + eos_token_id: int = -1 + pad_token_id: int = -1 + num_return_sequences: int = 1 + is_encoder_decoder: bool = False + padding: bool = True + early_stopping: bool = False + data_type: str = "bf16_fp16" + + +class XftModel: + def __init__(self, xft_model, xft_config): + self.model = xft_model + self.config = xft_config + + +def load_xft_model(model_path, xft_config: XftConfig): + try: + import xfastertransformer + from transformers import AutoTokenizer + except ImportError as e: + print(f"Error: Failed to load xFasterTransformer. 
{e}") + sys.exit(-1) + + if xft_config.data_type is None or xft_config.data_type == "": + data_type = "bf16_fp16" + else: + data_type = xft_config.data_type + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, padding_side="left", trust_remote_code=True + ) + xft_model = xfastertransformer.AutoModel.from_pretrained( + model_path, dtype=data_type + ) + model = XftModel(xft_model=xft_model, xft_config=xft_config) + if model.model.rank > 0: + while True: + model.model.generate() + return model, tokenizer diff --git a/backend/FastChat/fastchat/protocol/api_protocol.py b/backend/FastChat/fastchat/protocol/api_protocol.py new file mode 100644 index 0000000..2dc9944 --- /dev/null +++ b/backend/FastChat/fastchat/protocol/api_protocol.py @@ -0,0 +1,172 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class APIChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + user: Optional[str] = None + repetition_penalty: Optional[float] = 1.0 + frequency_penalty: Optional[float] = 0.0 + presence_penalty: Optional[float] = 0.0 + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class 
APITokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class APITokenCheckRequest(BaseModel): + prompts: List[APITokenCheckRequestItem] + + +class APITokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class APITokenCheckResponse(BaseModel): + prompts: List[APITokenCheckResponseItem] + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[int] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[float] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/backend/FastChat/fastchat/protocol/openai_api_protocol.py b/backend/FastChat/fastchat/protocol/openai_api_protocol.py new file mode 100644 index 0000000..6a00633 --- /dev/null +++ b/backend/FastChat/fastchat/protocol/openai_api_protocol.py @@ -0,0 +1,195 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +class 
ChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class TokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class TokenCheckRequest(BaseModel): + prompts: List[TokenCheckRequestItem] + + +class TokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class TokenCheckResponse(BaseModel): + prompts: List[TokenCheckResponseItem] + + +class EmbeddingsRequest(BaseModel): + model: Optional[str] = None + engine: Optional[str] = None + input: Union[str, List[Any]] + user: Optional[str] = None + encoding_format: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: str + usage: UsageInfo + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + use_beam_search: Optional[bool] = False + best_of: Optional[int] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = 
"text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/backend/FastChat/fastchat/serve/__init__.py b/backend/FastChat/fastchat/serve/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/FastChat/fastchat/serve/base_model_worker.py b/backend/FastChat/fastchat/serve/base_model_worker.py new file mode 100644 index 0000000..514cc82 --- /dev/null +++ b/backend/FastChat/fastchat/serve/base_model_worker.py @@ -0,0 +1,238 @@ +import asyncio +import threading +import time +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL +from fastchat.conversation import Conversation +from fastchat.utils import pretty_print_semaphore, build_logger + + +worker = None +logger = None + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + ): + global logger, worker + + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + self.conv = self.make_conv_template(conv_template, model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + if logger is None: + logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log") + if worker is None: + worker = self + + def make_conv_template( + self, + conv_template: str = None, + model_path: str = None, + ) -> Conversation: + """ + can be overrided to costomize the conversation template for different model workers. + """ + from fastchat.conversation import get_conv_template + from fastchat.model.model_adapter import get_conversation_template + + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + return conv + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, + args=(self,), + daemon=True, + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. " + f"worker_id: {self.worker_id}. 
" + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + self.semaphore is None + or self.semaphore._value is None + or self.semaphore._waiters is None + ): + return 0 + else: + return ( + self.limit_worker_concurrency + - self.semaphore._value + + len(self.semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + def generate_stream_gate(self, params): + raise NotImplementedError + + def generate_gate(self, params): + raise NotImplementedError + + def get_embeddings(self, params): + raise NotImplementedError + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await asyncio.to_thread(worker.generate_gate, params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} diff --git a/backend/FastChat/fastchat/serve/controller.py b/backend/FastChat/fastchat/serve/controller.py new file mode 100644 index 0000000..a67da62 --- /dev/null +++ b/backend/FastChat/fastchat/serve/controller.py @@ -0,0 +1,348 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. 
+""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import os +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from fastchat.constants import ( + CONTROLLER_HEART_BEAT_EXPIRATION, + WORKER_API_TIMEOUT, + ErrorCode, + SERVER_ERROR_MSG, +) +from fastchat.utils import build_logger + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stale_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,) + ) + self.heart_beat_thread.start() + + def register_worker( + self, worker_name: str, check_heart_beat: bool, worker_status: dict + ): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + ) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), 
p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stale_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def handle_no_worker(self, params): + logger.info(f"no worker: {params['model']}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_NO_WORKER, + } + return json.dumps(ret).encode() + b"\0" + + def handle_worker_timeout(self, worker_address): + logger.info(f"worker timeout: {worker_address}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT, + } + return json.dumps(ret).encode() + b"\0" + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. 
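+    #
+    # Illustrative sketch (addresses are assumptions, not part of this file):
+    # a parent controller can register this controller as if it were a worker,
+    # because the /worker_get_status and /worker_generate_stream endpoints
+    # below mirror the worker API:
+    #
+    #   requests.post(parent_controller_addr + "/register_worker",
+    #                 json={"worker_name": child_controller_addr,
+    #                       "check_heart_beat": False})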
+ def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + model_names = sorted(list(model_names)) + return { + "model_names": model_names, + "speed": speed, + "queue_length": queue_length, + } + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + yield self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + yield self.handle_worker_timeout(worker_addr) + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], data.get("worker_status", None) + ) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +@app.get("/test_connection") +async def worker_api_get_status(request: Request): + return "success" + + +def create_controller(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. 
Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + return args, controller + + +if __name__ == "__main__": + args, controller = create_controller() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/backend/FastChat/fastchat/serve/inference.py b/backend/FastChat/fastchat/serve/inference.py new file mode 100644 index 0000000..6d155aa --- /dev/null +++ b/backend/FastChat/fastchat/serve/inference.py @@ -0,0 +1,555 @@ +"""Inference for FastChat models.""" +import abc +import gc +import json +import math +import os +import sys +import time +from typing import Iterable, Optional, Dict +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + if hasattr(model, "device"): + device = model.device + + # Read parameters + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. 
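+    # The `params` dict for a request looks roughly like the following
+    # (illustrative values only; the keys are the ones read in this block):
+    #   {"prompt": "...", "temperature": 0.7, "repetition_penalty": 1.0,
+    #    "top_p": 1.0, "top_k": -1, "max_new_tokens": 256,
+    #    "stop": "</s>", "stop_token_ids": [], "echo": False}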
+ echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + input_ids = tokenizer(prompt).input_ids + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. + raise NotImplementedError + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values = out = None + token_logprobs = [None] # The first token has no logprobs. + sent_interrupt = False + finish_reason = None + stopped = False + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(input_ids=start_ids, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. + shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. 
+ token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs + if echo + else token_logprobs[input_echo_len:], + "top_logprobs": [{}] + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # Finish stream event, which contains finish reason + else: + finish_reason = "length" + + if stopped: + finish_reason = "stop" + + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + @abc.abstractmethod + def print_output(self, text: str): + """Print output.""" + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: 
int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, +): + # Model + model, tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + revision=revision, + debug=debug, + ) + generate_stream_func = get_generate_stream_function(model, model_path) + + model_type = str(type(model)).lower() + is_t5 = "t5" in model_type + is_codet5p = "codet5p" in model_type + is_xft = "xft" in model_type + + # Hardcode T5's default repetition penalty to be 1.2 + if is_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Set context length + context_len = get_context_length(model.config) + + # Chat + def new_chat(): + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + if conv_system_msg is not None: + conv.set_system_message(conv_system_msg) + return conv + + def reload_conv(conv): + """ + Reprints the conversation from the start. + """ + for message in conv.messages[conv.offset :]: + chatio.prompt_for_output(message[0]) + chatio.print_output(message[1]) + + conv = None + + while True: + if not history or not conv: + conv = new_chat() + + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + + if inp == "!!exit" or not inp: + print("exit...") + break + elif inp == "!!reset": + print("resetting...") + conv = new_chat() + continue + elif inp == "!!remove": + print("removing last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + reload_conv(conv) + else: + print("No messages to remove.") + continue + elif inp == "!!regen": + print("regenerating last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + reload_conv(conv) + # Set inp to previous message + inp = conv.messages.pop()[1] + else: + # Shouldn't happen in normal circumstances + print("No user message to regenerate from.") + continue + else: + print("No messages to regenerate.") + continue + elif inp.startswith("!!save"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." 
in filename: + filename += ".json" + + print("saving...", filename) + with open(filename, "w") as outfile: + json.dump(conv.dict(), outfile) + continue + elif inp.startswith("!!load"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!load ") + continue + else: + filename = args[1] + + # Check if file exists and add .json if needed + if not os.path.exists(filename): + if (not filename.endswith(".json")) and os.path.exists( + filename + ".json" + ): + filename += ".json" + else: + print("file not found:", filename) + continue + + print("loading...", filename) + with open(filename, "r") as infile: + new_conv = json.load(infile) + + conv = get_conv_template(new_conv["template_name"]) + conv.set_system_message(new_conv["system_message"]) + conv.messages = new_conv["messages"] + reload_conv(conv) + continue + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if is_codet5p: # codet5p is a code completion model. + prompt = inp + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + try: + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func( + model, + tokenizer, + gen_params, + device, + context_len=context_len, + judge_sent_end=judge_sent_end, + ) + t = time.time() + outputs = chatio.stream_output(output_stream) + duration = time.time() - t + conv.update_last_message(outputs.strip()) + + if debug: + num_tokens = len(tokenizer.encode(outputs)) + msg = { + "conv_template": conv.name, + "prompt": prompt, + "outputs": outputs, + "speed (token/s)": round(num_tokens / duration, 2), + } + print(f"\n{msg}\n") + + except KeyboardInterrupt: + print("stopped generation.") + # If generation didn't finish + if conv.messages[-1][1] is None: + conv.messages.pop() + # Remove last user message, so there isn't a double up + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + + reload_conv(conv) diff --git a/backend/FastChat/fastchat/serve/model_worker.py b/backend/FastChat/fastchat/serve/model_worker.py new file mode 100644 index 0000000..2ed8879 --- /dev/null +++ b/backend/FastChat/fastchat/serve/model_worker.py @@ -0,0 +1,388 @@ +""" +A model worker that executes the model. 
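+
+The worker loads the model via fastchat.model.model_adapter.load_model and
+serves the HTTP endpoints defined in fastchat.serve.base_model_worker (e.g.
+/worker_generate_stream, /worker_generate, /worker_get_embeddings),
+registering itself with the controller unless --no-register is given.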
+""" +import argparse +import base64 +import gc +import json +import os +from typing import List, Optional +import uuid + +import torch +import torch.nn.functional as F +from transformers import set_seed +import uvicorn + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.base_model_worker import BaseModelWorker, app +from fastchat.utils import ( + build_logger, + get_context_length, + str_to_torch_dtype, +) +from fastapi.middleware.cors import CORSMiddleware + +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:3000"], # Allows all origins + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers +) + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + stream_interval: int = 2, + conv_template: Optional[str] = None, + embed_in_truncate: bool = False, + seed: Optional[int] = None, + debug: bool = False, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + debug=debug, + ) + self.device = device + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function(self.model, model_path) + self.stream_interval = stream_interval + self.embed_in_truncate = embed_in_truncate + self.seed = seed + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + self.call_ct += 1 + + try: + if self.seed is not None: + set_seed(self.seed) + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": 
ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): + if model_type_dict.get("is_bert"): + model_output = self.model(input_ids) + if model_type_dict.get("is_robert"): + data = model_output.last_hidden_state + else: + data = model_output[0] + elif model_type_dict.get("is_t5"): + model_output = self.model(input_ids, decoder_input_ids=input_ids) + data = model_output.encoder_last_hidden_state + else: + model_output = self.model(input_ids, output_hidden_states=True) + if model_type_dict.get("is_chatglm"): + data = model_output.hidden_states[-1].transpose(0, 1) + else: + data = model_output.hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + token_num = torch.sum(attention_mask).item() + + return sum_embeddings, token_num + + def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: + embeddings = embeddings.cpu() + return [ + base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings + ] + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + ret = {"embedding": [], "token_num": 0} + + model_type_dict = { + "is_llama": "llama" in str(type(self.model)), + "is_t5": "t5" in str(type(self.model)), + "is_chatglm": "chatglm" in str(type(self.model)), + "is_bert": "bert" in str(type(self.model)), + "is_robert": "robert" in str(type(self.model)), + } + + if self.embed_in_truncate: + encoding = tokenizer.batch_encode_plus( + params["input"], + padding=True, + truncation="longest_first", + return_tensors="pt", + max_length=self.context_len, + ) + else: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = input_ids != tokenizer.pad_token_id + + base64_encode = params.get("encoding_format", None) + + if self.embed_in_truncate: + chunk_embeddings, token_num = self.__process_embed_chunk( + input_ids, attention_mask, **model_type_dict + ) + embedding = chunk_embeddings / token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret["token_num"] = token_num + else: + all_embeddings = [] + all_token_num = 0 + for i in range(0, input_ids.size(1), self.context_len): + chunk_input_ids = input_ids[:, i : i + self.context_len] + chunk_attention_mask = attention_mask[:, i : i + self.context_len] + + chunk_embeddings, token_num = self.__process_embed_chunk( + chunk_input_ids, chunk_attention_mask, **model_type_dict + ) + all_embeddings.append(chunk_embeddings) + all_token_num += token_num + + all_embeddings_tensor = torch.stack(all_embeddings) + embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + + ret["token_num"] = all_token_num + + if base64_encode == "base64": + out_embeddings = self.__encode_base64(normalized_embeddings) + else: + out_embeddings = normalized_embeddings.tolist() + ret["embedding"] = out_embeddings + + gc.collect() + torch.cuda.empty_cache() + if self.device == "xpu": + 
torch.xpu.empty_cache() + if self.device == "npu": + torch.npu.empty_cache() + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--embed-in-truncate", action="store_true") + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--debug", type=bool, default=False, help="Print debugging messages" + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. 
Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, + seed=args.seed, + debug=args.debug, + ) + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/backend/FastChat/fastchat/serve/openai_api_server.py b/backend/FastChat/fastchat/serve/openai_api_server.py new file mode 100644 index 0000000..120bc42 --- /dev/null +++ b/backend/FastChat/fastchat/serve/openai_api_server.py @@ -0,0 +1,903 @@ +"""A server that provides OpenAI-compatible RESTful APIs. It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio +import argparse +import json +import logging +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import aiohttp +import fastapi +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +import httpx +from pydantic import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import ( + WORKER_API_TIMEOUT, + WORKER_API_EMBEDDING_BATCH_SIZE, + ErrorCode, +) +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) + +logger = logging.getLogger(__name__) + +conv_template_map = {} + +fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600) + + +async def fetch_remote(url, pload=None, name=None): + async with aiohttp.ClientSession(timeout=fetch_timeout) as session: + async with session.post(url, json=pload) as response: + chunks = [] + if response.status != 200: + ret = { + "text": f"{response.reason}", + "error_code": ErrorCode.INTERNAL_ERROR, 
+ } + return json.dumps(ret) + + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks) + + if name is not None: + res = json.loads(output) + if name != "": + res = res[name] + return res + + return output + + +class AppSettings(BaseSettings): + # The address of the model controller. + controller_address: str = "http://localhost:21001" + api_keys: Optional[List[str]] = None + + +app_settings = AppSettings() +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if app_settings.api_keys: + if auth is None or (token := auth.credentials) not in app_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).dict(), status_code=400 + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + + models = await fetch_remote(controller_address + "/list_models", None, "models") + if request.model not in models: + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens, worker_addr): + if ( + not isinstance(max_tokens, int) or max_tokens <= 0 + ): # model worker not support max_tokens=None + max_tokens = 1024 * 1024 + + context_len = await fetch_remote( + worker_addr + "/model_details", {"model": request.model}, "context_length" + ) + token_num = await fetch_remote( + worker_addr + "/count_token", + {"model": request.model, "prompt": prompt}, + "count", + ) + length = min(max_tokens, context_len - token_num) + + if length <= 0: + return None, create_error_response( + ErrorCode.CONTEXT_OVERFLOW, + f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. 
Please reduce the length of the messages.", + ) + + return length, None + + +def check_requests(request) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + +def process_input(model_name, inp): + if isinstance(inp, str): + inp = [inp] + elif isinstance(inp, list): + if isinstance(inp[0], int): + decoding = tiktoken.model.encoding_for_model(model_name) + inp = [decoding.decode(inp)] + elif isinstance(inp[0], list): + decoding = tiktoken.model.encoding_for_model(model_name) + inp = [decoding.decode(text) for text in inp] + + return inp + + +def create_openai_logprobs(logprob_dict): + """Create OpenAI-style logprobs.""" + return LogProbs(**logprob_dict) if logprob_dict is not None else None + + +def _add_to_set(s, new_stop): + if not s: + return + if isinstance(s, str): + new_stop.add(s) + else: + new_stop.update(s) + + +async def get_gen_params( + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, +) -> Dict[str, Any]: + conv = await get_conv(model_name, worker_addr) + conv = Conversation( + name=conv["name"], + system_template=conv["system_template"], + system_message=conv["system_message"], + roles=conv["roles"], + messages=list(conv["messages"]), # prevent in-place modification + offset=conv["offset"], + sep_style=SeparatorStyle(conv["sep_style"]), + sep=conv["sep"], + sep2=conv["sep2"], + stop_str=conv["stop_str"], + stop_token_ids=conv["stop_token_ids"], + ) + + if isinstance(messages, str): + prompt = messages + else: + for message in messages: + msg_role = message["role"] + if msg_role 
== "system": + conv.set_system_message(message["content"]) + elif msg_role == "user": + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "logprobs": logprobs, + "top_p": top_p, + "top_k": top_k, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "max_new_tokens": max_tokens, + "echo": echo, + "stop_token_ids": conv.stop_token_ids, + } + + if best_of is not None: + gen_params.update({"best_of": best_of}) + if use_beam_search is not None: + gen_params.update({"use_beam_search": use_beam_search}) + + new_stop = set() + _add_to_set(stop, new_stop) + _add_to_set(conv.stop_str, new_stop) + + gen_params["stop"] = list(new_stop) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def get_worker_address(model_name: str) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + worker_addr = await fetch_remote( + controller_address + "/get_worker_address", {"model": model_name}, "address" + ) + + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +async def get_conv(model_name: str, worker_addr: str): + conv_template = conv_template_map.get((worker_addr, model_name)) + if conv_template is None: + conv_template = await fetch_remote( + worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv" + ) + conv_template_map[(worker_addr, model_name)] = conv_template + return conv_template + + +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def show_available_models(): + controller_address = app_settings.controller_address + ret = await fetch_remote(controller_address + "/refresh_all_workers") + models = await fetch_remote(controller_address + "/list_models", None, "models") + + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + + +@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + 
gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + if "usage" in content: + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=content.get("finish_reason", None), + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". 
+ for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +@app.post("/v1/completions", dependencies=[Depends(check_api_key)]) +async def create_completion(request: CompletionRequest): + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + request.prompt = process_input(request.model, request.prompt) + + worker_addr = await get_worker_address(request.model) + for text in request.prompt: + max_tokens, error_check_ret = await check_length( + request, text, request.max_tokens, worker_addr + ) + if error_check_ret is not None: + return error_check_ret + + if isinstance(max_tokens, int) and max_tokens < request.max_tokens: + request.max_tokens = max_tokens + + if request.stream: + generator = generate_completion_stream_generator( + request, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for text in request.prompt: + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + best_of=request.best_of, + use_beam_search=request.use_beam_search, + ) + for i in range(request.n): + content = asyncio.create_task( + generate_completion(gen_params, worker_addr) + ) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) + ) + + +async def generate_completion_stream_generator( + request: CompletionRequest, n: int, worker_addr: str +): + model_name = request.model + id = f"cmpl-{shortuuid.random()}" + finish_stream_events = [] + for text in request.prompt: + for i in range(n): + previous_text = "" + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + ) + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + 
decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + # todo: index is not apparent + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, + object="text_completion", + choices=[choice_data], + model=model_name, + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + buffer = b"" + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while (chunk_end := buffer.find(delimiter)) >= 0: + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + if not chunk: + continue + yield json.loads(chunk.decode()) + + +async def generate_completion(payload: Dict[str, Any], worker_addr: str): + return await fetch_remote(worker_addr + "/worker_generate", payload, "") + + +@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)]) +@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)]) +async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): + """Creates embeddings for the text""" + if request.model is None: + request.model = model_name + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + request.input = process_input(request.model, request.input) + + data = [] + token_num = 0 + batch_size = WORKER_API_EMBEDDING_BATCH_SIZE + batches = [ + request.input[i : min(i + batch_size, len(request.input))] + for i in range(0, len(request.input), batch_size) + ] + for num_batch, batch in enumerate(batches): + payload = { + "model": request.model, + "input": batch, + "encoding_format": request.encoding_format, + } + embedding = await get_embedding(payload) + if "error_code" in embedding and embedding["error_code"] != 0: + return create_error_response(embedding["error_code"], embedding["text"]) + data += [ + { + "object": "embedding", + "embedding": emb, + "index": num_batch * batch_size + i, + } + for i, emb in enumerate(embedding["embedding"]) + ] + token_num += embedding["token_num"] + return EmbeddingsResponse( + data=data, + model=request.model, + usage=UsageInfo( + prompt_tokens=token_num, + total_tokens=token_num, + completion_tokens=None, + ), + ).dict(exclude_none=True) + + +async def get_embedding(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + model_name = payload["model"] + worker_addr = await get_worker_address(model_name) + + embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload) + return json.loads(embedding) + + +### GENERAL API - NOT OPENAI COMPATIBLE ### + + 
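Before the non-OpenAI routes below, here is an illustrative sketch of how a client might exercise the OpenAI-compatible endpoints defined above. It is not part of this patch: the host, port, served model names, and the absence of --api-keys are assumptions taken from the docker-compose further down in this diff, not guarantees made by this module.

# Client sketch for the OpenAI-compatible routes above (illustrative only).
# Host, port, and model names are assumed from the docker-compose in this diff.
import requests

BASE_URL = "http://localhost:8000/v1"

# Chat completion (non-streaming): POST /v1/chat/completions
chat = requests.post(
    f"{BASE_URL}/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "What is FastChat?"}],
    },
    timeout=120,
)
print(chat.json()["choices"][0]["message"]["content"])

# Embeddings: POST /v1/embeddings (the server batches the input list
# internally using WORKER_API_EMBEDDING_BATCH_SIZE)
emb = requests.post(
    f"{BASE_URL}/embeddings",
    json={"model": "text-embedding-ada-002", "input": ["first text", "second text"]},
    timeout=120,
)
print([item["index"] for item in emb.json()["data"]])

Streaming works the same way by adding "stream": true to the chat payload, in which case the server responds with the text/event-stream chunks produced by chat_completion_stream_generator above and terminates the stream with "data: [DONE]".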
+@app.post("/api/v1/token_check") +async def count_tokens(request: APITokenCheckRequest): + """ + Checks the token count for each message in your list + This is not part of the OpenAI API spec. + """ + checkedList = [] + for item in request.prompts: + worker_addr = await get_worker_address(item.model) + + context_len = await fetch_remote( + worker_addr + "/model_details", + {"prompt": item.prompt, "model": item.model}, + "context_length", + ) + + token_num = await fetch_remote( + worker_addr + "/count_token", + {"prompt": item.prompt, "model": item.model}, + "count", + ) + + can_fit = True + if token_num + item.max_tokens > context_len: + can_fit = False + + checkedList.append( + APITokenCheckResponseItem( + fits=can_fit, contextLength=context_len, tokenCount=token_num + ) + ) + + return APITokenCheckResponse(prompts=checkedList) + + +@app.post("/api/v1/chat/completions") +async def create_chat_completion(request: APIChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + if request.repetition_penalty is not None: + gen_params["repetition_penalty"] = request.repetition_penalty + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +### END GENERAL API - NOT OPENAI COMPATIBLE ### + + +def create_openai_api_server(): + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." 
+ ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + app_settings.api_keys = args.api_keys + + logger.info(f"args: {args}") + return args + + +if __name__ == "__main__": + args = create_openai_api_server() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/backend/FastChat/fastchat/utils.py b/backend/FastChat/fastchat/utils.py new file mode 100644 index 0000000..b5e3ba5 --- /dev/null +++ b/backend/FastChat/fastchat/utils.py @@ -0,0 +1,349 @@ +""" +Common utilities. +""" +from asyncio import AbstractEventLoop +import json +import logging +import logging.handlers +import os +import platform +import sys +from typing import AsyncGenerator, Generator +import warnings + +import requests + +from fastchat.constants import LOGDIR + + +handler = None +visited_loggers = set() + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." 
+ ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # if LOGDIR is empty, then don't try output log to local file + if LOGDIR != "": + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + + for l in [stdout_logger, stderr_logger, logger]: + if l in visited_loggers: + continue + visited_loggers.add(l) + l.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_gpu_memory(max_gpus=None): + """Get available memory for each GPU.""" + import torch + + gpu_memory = [] + num_gpus = ( + torch.cuda.device_count() + if max_gpus is None + else min(max_gpus, torch.cuda.device_count()) + ) + + for gpu_id in range(num_gpus): + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def oai_moderation(text): + """ + Check whether the text violates OpenAI moderation API. 
+ """ + import openai + + openai.api_base = "https://api.openai.com/v1" + openai.api_key = os.environ["OPENAI_API_KEY"] + + MAX_RETRY = 3 + for i in range(MAX_RETRY): + try: + res = openai.Moderation.create(input=text) + flagged = res["results"][0]["flagged"] + break + except (openai.error.OpenAIError, KeyError, IndexError) as e: + # flag true to be conservative + flagged = True + print(f"MODERATION ERROR: {e}\nInput: {text}") + return flagged + + +def moderation_filter(text, model_list): + MODEL_KEYWORDS = ["claude"] + + for keyword in MODEL_KEYWORDS: + for model in model_list: + if keyword in model and oai_moderation(text): + return True + return False + + +def clean_flant5_ckpt(ckpt_path): + """ + Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, + Use this function to make sure it can be correctly loaded. + """ + import torch + + index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + index_json = json.load(open(index_file, "r")) + + weightmap = index_json["weight_map"] + + share_weight_file = weightmap["shared.weight"] + share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ + "shared.weight" + ] + + for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: + weight_file = weightmap[weight_name] + weight = torch.load(os.path.join(ckpt_path, weight_file)) + weight[weight_name] = share_weight + torch.save(weight, os.path.join(ckpt_path, weight_file)) + + +def pretty_print_semaphore(semaphore): + """Print a semaphore in better format.""" + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +"""A javascript function to get url parameters for the gradio web server.""" +get_window_url_params_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + return url_params; + } +""" + + +get_window_url_params_with_tos_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + + msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nThe service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license." 
+ alert(msg); + + return url_params; + } +""" + + +def iter_over_async( + async_gen: AsyncGenerator, event_loop: AbstractEventLoop +) -> Generator: + """ + Convert async generator to sync generator + + :param async_gen: the AsyncGenerator to convert + :param event_loop: the event loop to run on + :returns: Sync generator + """ + ait = async_gen.__aiter__() + + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + + while True: + done, obj = event_loop.run_until_complete(get_next()) + if done: + break + yield obj + + +def detect_language(text: str) -> str: + """Detect the langauge of a string.""" + import polyglot # pip3 install polyglot pyicu pycld2 + from polyglot.detect import Detector + from polyglot.detect.base import logger as polyglot_logger + import pycld2 + + polyglot_logger.setLevel("ERROR") + + try: + lang_code = Detector(text).language.name + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + return lang_code + + +def parse_gradio_auth_creds(filename: str): + """Parse a username:password file for gradio authorization.""" + gradio_auth_creds = [] + with open(filename, "r", encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()] + if gradio_auth_creds: + auth = [tuple(cred.split(":")) for cred in gradio_auth_creds] + else: + auth = None + return auth + + +def is_partial_stop(output: str, stop_str: str): + """Check whether the output contains a partial stop str.""" + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +def run_cmd(cmd: str): + """Run a bash command.""" + print(cmd) + return os.system(cmd) + + +def is_sentence_complete(output: str): + """Check whether the output is a complete sentence.""" + end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”") + return output.endswith(end_symbols) + + +# Models don't use the same configuration key for determining the maximum +# sequence length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. 
+SEQUENCE_LENGTH_KEYS = [ + "max_sequence_length", + "seq_length", + "max_position_embeddings", + "max_seq_len", + "model_max_length", +] + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling: + rope_scaling_factor = config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + for key in SEQUENCE_LENGTH_KEYS: + val = getattr(config, key, None) + if val is not None: + return int(rope_scaling_factor * val) + return 2048 + + +def str_to_torch_dtype(dtype: str): + import torch + + if dtype is None: + return None + elif dtype == "float32": + return torch.float32 + elif dtype == "float16": + return torch.float16 + elif dtype == "bfloat16": + return torch.bfloat16 + else: + raise ValueError(f"Unrecognized dtype: {dtype}") diff --git a/backend/FastChat/pyproject.toml b/backend/FastChat/pyproject.toml new file mode 100644 index 0000000..b6db034 --- /dev/null +++ b/backend/FastChat/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "fschat" +version = "0.2.33" +description = "An open platform for training, serving, and evaluating large language model based chatbots." +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy", + "prompt_toolkit>=3.0.0", "pydantic<2,>=1", "requests", "rich>=10.0.0", + "shortuuid", "tiktoken", "uvicorn", +] + +[project.optional-dependencies] +model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"] +webui = ["gradio"] +train = ["einops", "flash-attn>=2.0", "wandb"] +llm_judge = ["openai<1", "anthropic>=0.3", "ray"] +dev = ["black==23.3.0", "pylint==2.8.2"] + +[project.urls] +"Homepage" = "https://github.com/lm-sys/fastchat" +"Bug Tracker" = "https://github.com/lm-sys/fastchat/issues" + +[tool.setuptools.packages.find] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] + +[tool.wheel] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml index c997ae5..16c861a 100644 --- a/backend/docker-compose.yml +++ b/backend/docker-compose.yml @@ -1,7 +1,60 @@ -version: "3.3" +version: "3.9" + services: - ws: + fastchat-controller: + build: + context: ./FastChat/ + dockerfile: Dockerfile + networks: + - ws + image: fastchat:latest + ports: + - "21001:21001" + entrypoint: ["python3.9", "-m","fastchat.serve.controller", "--host", "0.0.0.0", "--port", "21001"] + + fastchat-model-worker: + build: + context: ./FastChat/ + dockerfile: Dockerfile + volumes: + - huggingface:/root/.cache/huggingface + - type: bind + source: ./FastChat/fastchat + target: /app/fastchat + networks: + - ws + image: fastchat:latest + ports: + - "21002:21002" + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + command: sh -c "huggingface-cli login --token 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 --add-to-git-credential && python3.9 -m fastchat.serve.model_worker --model-path epfl-llm/meditron-7b --model-names "gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002" --controller-address http://fastchat-controller:21001 --worker-address 
http://fastchat-model-worker:21002 --host 0.0.0.0 --conv-template one_shot_medical" + + openai-api-server: + build: + context: ./FastChat/ + dockerfile: Dockerfile + volumes: + - type: bind + source: ./FastChat/fastchat + target: /app/fastchat + networks: + - ws + image: fastchat:latest + ports: + - "8000:8000" + command: python3.9 -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8000 --controller-address http://fastchat-controller:21001 + + langchain: image: healthdao-pdf-processing:latest + environment: + OPENAI_API_BASE: http://openai-api-server:8000/v1 + OPENAI_API_KEY: EMPTY container_name: healthdao-pdf-processing build: context: ./python-langchain/ @@ -12,5 +65,9 @@ services: volumes: - ./python-langchain:/api mem_limit: 15G + +volumes: + huggingface: + networks: - ws: \ No newline at end of file + ws: diff --git a/backend/python-langchain/.gitignore b/backend/python-langchain/.gitignore index 1d308ac..06d07b2 100644 --- a/backend/python-langchain/.gitignore +++ b/backend/python-langchain/.gitignore @@ -6,4 +6,5 @@ __pycache__/ .env flask_session/ vectors/ -breast-cancer-S1500-O200-keys/ \ No newline at end of file +breast-cancer-S1500-O200-keys/ +ovarian-cancer-vector/ \ No newline at end of file diff --git a/backend/python-langchain/main.py b/backend/python-langchain/main.py index 152fb17..5c9489e 100644 --- a/backend/python-langchain/main.py +++ b/backend/python-langchain/main.py @@ -1,9 +1,7 @@ from functools import wraps from flask import Flask, request, session, jsonify from flask_cors import CORS -import os from os import getenv -import requests import shutil from flask_socketio import SocketIO, emit, disconnect from flask_session import Session @@ -12,44 +10,24 @@ from dotenv import load_dotenv from pathvalidate import sanitize_filename -from pypdf import PdfReader + import os -from PIL import Image -import pytesseract from dotenv import load_dotenv import re - from models.User import User -from langchain.document_loaders import DirectoryLoader -from langchain.embeddings.openai import OpenAIEmbeddings +import sys +from langchain.embeddings import HuggingFaceInstructEmbeddings from langchain.vectorstores import Chroma -from langchain.llms import OpenAI -from langchain.text_splitter import CharacterTextSplitter -from langchain.agents.tools import Tool -from langchain.chains import RetrievalQA from langchain.chat_models import ChatOpenAI -from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser -from langchain.prompts import StringPromptTemplate -from langchain import LLMChain -from typing import List, Union -from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.prompts import PromptTemplate -from src.pupmed import PubMedRetriever -from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory -from langchain.callbacks.base import BaseCallbackHandler -from langchain.embeddings import HuggingFaceInstructEmbeddings - - -from queue import Queue -import sys -from flask import Response, stream_with_context -import threading +from langchain.schema import StrOutputParser sys.setrecursionlimit(1000) # Load environment variables from .env file load_dotenv() # Access the URL variables +openai_token = getenv('OPENAI_TOKEN') api_url = getenv('LANGCHAIN_UPLOAD_ENDPOINT') delete_url = getenv('LANGCHAIN_DELETE_ENDPOINT') cors_domains = getenv('CORS_DOMAINS').split(',') @@ -72,99 +50,48 @@ socketio = SocketIO(app, cors_allowed_origins=cors_domains, manage_session=False) -# *** TEMPLATE *** -template = 
"""Answer the following questions as best you can only using the tools, You are a friendly, conversational personal medical assistant. -You have access to the following tools: - -{tools} - -Use the following format: - -Question: the input question you must answer -Thought: you should always think about what to do -Action: the action to take, should be one of [{tool_names}] -Action Input: the input to the action -Observation: the result of the action -... (this Thought/Action/Action Input/Observation can repeat N times until there is enough information to answer the Question) -Thought: I now know the final answer -Final Answer: the final answer to the original input question. The answer should be a text with the "answer" following by a json with the key "sources" as a list of urls generated by the uid. Add "END" at the end of the final answer. +# load from disk +instructor_embeddings = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base') +db3 = Chroma(persist_directory="./ovarian-cancer-vector", embedding_function=instructor_embeddings) +# create the open-source embedding function +retriever = db3.as_retriever( + search_type="mmr", # Also test "similarity" + search_kwargs={"k": 6}, +) -Chat history: -{chat_history} +general_prompt = PromptTemplate.from_template("{question}") +general_runnable = general_prompt | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser() -New question: {input} -{agent_scratchpad}""" +trials_prompt = PromptTemplate.from_template( + """ + *** Instructions *** + Use the following pieces of context about clinical trials to answer the user's question. Add the necessary information about the trials. If you don't know the answer, just say that you don't know, don't try to make up an answer. -prompt_template_health_vector = """Use the following context about medical records to show to the user the information they need, help find exactly what they want. - You must give information about user data, especially negative results that may affect the user's health. + *** Context *** + {context} - The context could be in multiple languages. You should always respond in English even if the context is in another language. - It's ok if you don't know the answer. + *** User's Question *** + {question}""" +) +trials_runnable = trials_prompt | ChatOpenAI(model="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key=openai_token) | StrOutputParser() - Context: - {context} +classification_chain = ( + PromptTemplate.from_template( + """Given the user question below, classify it as either being about `clinicalTrials`, `cancerGuidelines`, or `general`. - Question: {question} - Answer:""" +Do not respond with more than one word. -prompt_template_trials_vector = """Use the following context about medical trials to match the users health data (inclusion and exclusion criteria) with the best medical trial option. + +{question} + - If the trial is for a specific sex or age, you need to be sure you have this data of the patient, otherwise it's not a match. - It's ok if you don't find a match. - - - Always refer by the nctID when you talk about a trial in the answer. 
+Classification:""" + ) + | ChatOpenAI(model="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key=openai_token) + | StrOutputParser() +) - Context: - {context} - Question: {question} - Answer:""" - -# *** Set up a prompt template *** -class CustomPromptTemplate(StringPromptTemplate): - # The template to use - template: str - # The list of tools available - tools: List[Tool] - - def format(self, **kwargs) -> str: - # Get the intermediate steps (AgentAction, Observation tuples) - # Format them in a particular way - intermediate_steps = kwargs.pop("intermediate_steps") - thoughts = "" - for action, observation in intermediate_steps: - thoughts += action.log - thoughts += f"\nObservation: {observation}\nThought: " - # Set the agent_scratchpad variable to that value - kwargs["agent_scratchpad"] = thoughts - # Create a tools variable from the list of tools provided - kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) - # Create a list of tool names for the tools provided - kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) - return self.template.format(**kwargs) - -# *** Output Parser *** -class CustomOutputParser(AgentOutputParser): - def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]: - # Check if agent should finish - if "Final Answer:" in llm_output: - return AgentFinish( - # Return values is generally always a dictionary with a single `output` key - # It is not recommended to try anything else at the moment :) - return_values={"output": llm_output.split("Final Answer:")[-1].strip()}, - log=llm_output, - ) - # Parse out the action and action input - regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" - match = re.search(regex, llm_output, re.DOTALL) - if not match: - raise OutputParserException(f"Could not parse LLM output: `{llm_output}`") - action = match.group(1).strip() - action_input = match.group(2) - # Return the action and action input - return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output) - def authenticated_only(f): @wraps(f) def wrapped(*args, **kwargs): @@ -179,19 +106,20 @@ def wrapped(*args, **kwargs): def handle_connect(): emit('message', {'text': 'Connected'}) + @socketio.on('disconnect') def disconnect(): uuid = sanitize_filename(request.headers.get('uuid')) # Try to remove the tree; if it fails, throw an error using try...except. - for folder in ["temp/","images/","output/", "vectors/"]: + for folder in ["temp/", "images/", "output/", "vectors/"]: dir = folder + uuid + "/" try: shutil.rmtree(dir) print(f"INFO: folder {dir} delete suscesfully", flush=True) except OSError as e: - #pass print("Error: %s - %s." 
% (e.filename, e.strerror), flush=True) + @login_manager.user_loader def load_user(user_id): return User(user_id) @@ -209,281 +137,30 @@ def login(): return jsonify({'status': 200, 'userId': uuid}), 200 - @app.route('/send-message', methods=['GET', 'POST']) @login_required def send_message(): message = request.json.get('message', '') - history = request.json.get('history', '') - userId = sanitize_filename(request.json.get('userId', '')) - print('INFO: Message recived ', flush=True) + docs = retriever.get_relevant_documents(message) + text = "" - # ************ HEALTH DATA VECTOR ************ - llm = OpenAI(temperature=0) + classification = classification_chain.invoke({"question": message}) + print("classification: `" + classification, flush=True) - embeddings = OpenAIEmbeddings() - memory = ConversationBufferMemory(memory_key="chat_history") - readonlymemory = ReadOnlySharedMemory(memory=memory) + if "clinicaltrials" in classification.lower(): + for doc in docs: + text += f"""Clinical trial nctId: {doc.metadata["nctId"]} \nInformation: {doc.page_content} \n\n""" + response = trials_runnable.invoke({"question": message, "context": text}) + return jsonify({'response': response}) - vectorstore = Chroma(persist_directory=f"./vectors/{userId}/", embedding_function=embeddings) + elif "cancerguidelines" or "general" in classification.lower(): + response = general_runnable.invoke({"question": message}) + return jsonify({'response': response}) - PROMPT = PromptTemplate( - template=prompt_template_health_vector, input_variables=["context", "question"] - ) - chain_type_kwargs = {"prompt": PROMPT} - - health_data_vectorstore = RetrievalQA.from_chain_type( - llm=llm, - chain_type="stuff", - retriever=vectorstore.as_retriever(), - chain_type_kwargs=chain_type_kwargs, - memory=readonlymemory, - ) - - - # ************ PUBMED RETRIEVER ************ - pubmed_chain = RetrievalQA.from_chain_type( - llm=llm, - chain_type="stuff", - retriever=PubMedRetriever(top_k_results=5), - memory=readonlymemory, - return_source_documents=True, - input_key="question" - ) + else: + response = general_runnable.invoke({"question": message}) + return jsonify({'response': response}) - # ************ MEDICAL TRIALS VECTOR ************ - instructor_embeddings = HuggingFaceInstructEmbeddings(model_name = 'hkunlp/instructor-base') - vectorstore_trials = Chroma(persist_directory="breast-cancer-S1500-O200-keys/", embedding_function = instructor_embeddings) - - PROMPT = PromptTemplate( - template=prompt_template_trials_vector, input_variables=["context", "question"] - ) - chain_type_kwargs = {"prompt": PROMPT} - - clinical_trials_vectorstore = RetrievalQA.from_chain_type( - llm=ChatOpenAI(model_name="gpt-4",temperature=0.3), - chain_type="stuff", - retriever=vectorstore_trials.as_retriever(), - chain_type_kwargs=chain_type_kwargs, - return_source_documents=True, input_key="question" - ) - - # ************ TOOLS ************ - tools = [ - Tool( - name = "pubmed-query-search", - func = lambda query: pubmed_chain({"question": query}), - description = "Pubmed Query Search - Useful to obtain health information resources for recommendations from pubmed, Input should be a fully formed question." - ), - Tool( - name="health-documents-vector", - func=health_data_vectorstore.run, - description="Health Documents QA - useful to obtain information about the users private health data. 
The input should be the original question", - - ), - Tool( - name="clinical-trials-vector", - func=lambda query: clinical_trials_vectorstore({"question": query}), - description="Clinical Trials Match - useful to to find a Clinical Trial with the users data (inclusion and exclusion criteria). The input should be the inclusion and exclusion criteria and related data", - - ) - ] - - # ************ CUSTOM AGENT ************ - prompt_with_history = CustomPromptTemplate( - template=template, - tools=tools, - # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically - # This includes the `intermediate_steps` variable because that is needed - input_variables=["input", "intermediate_steps", "chat_history"] - ) - - output_parser = CustomOutputParser() - - q = Queue() - callback_fn = MyCallbackHandler(q) - # LLM chain consisting of the LLM and a prompt - llm_chain = LLMChain(llm=ChatOpenAI(streaming=True, callbacks=[callback_fn],model_name='gpt-4', temperature=0), prompt=prompt_with_history) - - tool_names = [tool.name for tool in tools] - agent = LLMSingleActionAgent( - llm_chain=llm_chain, - output_parser=output_parser, - stop=["\nObservation:"], - allowed_tools=tool_names - ) - - agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory) - def generate(rq: Queue): - - while(True): - result = rq.get() - if result == " END": - break - if result is not None: - yield result - else: - break - - agent_thread = threading.Thread(target=agent_executor.run, args=(message,)) - agent_thread.start() - return Response(stream_with_context(generate(q)), content_type='text/event-stream', headers={'X-Accel-Buffering': 'no'}) - -class MyCallbackHandler(BaseCallbackHandler): - def __init__(self, q): - self.q = q - self.last_tokens = [] - self.answer_reached = False - - def on_llm_start(self, *args, **kwargs) -> None: - """Run when LLM starts running.""" - self.q.queue.clear() - def on_llm_new_token(self, token, **kwargs) -> None: - self.last_tokens.append(token) - - if self.answer_reached: - self.q.put(token) - - if len(self.last_tokens) > 3: - self.last_tokens.pop(0) - - if ''.join(self.last_tokens)=="Final Answer:": - self.answer_reached = True - - def on_llm_end(self, *args, **kwargs) -> None: - return self.q.empty() - - - -@app.route('/upload', methods=['GET', 'POST']) -@login_required -def upload_files(): - delete_old_files = request.form.get('deleteOldFiles', 'true') == 'true' - session['progress'] = 10 - session['message'] = 'Processing files' - - user_id = sanitize_filename(request.form.get('userId', ''))+"/" - - if not os.path.exists('temp/'): os.makedirs('temp/') - if not os.path.exists('images/'): os.makedirs('images/') - if not os.path.exists('output/'): os.makedirs('output/') - - # Crear una carpeta con el ID del usuario - temp_path = 'temp/' + user_id - images_path = 'images/' + user_id - output_path = 'output/' + user_id - if not os.path.exists(temp_path): - os.makedirs(temp_path) - os.makedirs(images_path) - os.makedirs(output_path) - - - # Check if files were sent - if 'files' not in request.files: - return 'No files uploaded', 400 - - uploaded_files = request.files.getlist('files') - - # Get the names of the old files - old_files = get_file_names(temp_path) + get_file_names(images_path) + get_file_names(output_path) - - # Process each uploaded file - for file in uploaded_files: - # Check if the file is already present - if file.filename in old_files: - print(f'INFO: Skipping file {file.filename}, already 
processed') - continue - - - # Save the file or perform any desired operations - filename = file.filename.split(".") - if filename[1] =='pdf': - file.save(temp_path + file.filename) - - else: - file.save(output_path+file.filename) - - print('INFO: PyPDF Process') - session['progress'] = 33 - session['message'] = 'Extracting key information' - pypdf_process(old_files, images_path, output_path, temp_path) - - session['progress'] = 66 - session['message'] = 'Training model' - print('INFO: Create Vector', flush=True) - create_vector(output_path) - - print('INFO: Create Vector Successfully: ' + user_id, flush=True) - session['progress'] = 100 - session['message'] = 'All done!' - - return 'Files uploaded successfully' - - -@app.route('/progress') -def progress(): - return jsonify({'progress': session.get('progress', 0), 'message': session.get('message', '')}) - - -def get_file_names(dir): - if os.path.exists(dir): - return [f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))] - return [] - -def pypdf_process(old_files, images_path, output_path, temp_path): - files = os.listdir(temp_path) - - for file in files: - if file in old_files: - print(f'INFO: Skipping image {file}, already processed') - continue - reader = PdfReader(temp_path+file) - pages = len(reader.pages) - - text_list = [] - count = 0 - images_text = [] - - #Get images and text for every page - for page_num in range(pages): - page = reader.pages[page_num] - text_list.append(page.extract_text()) - - #Get images in page - for image_file_object in page.images: - with open(images_path + str(page_num) + "_" + file.split(".")[0] + "_" + str(count) + "_" + image_file_object.name, "wb") as fp: - fp.write(image_file_object.data) - images_text.append(pytesseract.image_to_string(Image.open(images_path + str(page_num) + "_" + file.split(".")[0] + "_" + str(count) + "_" + image_file_object.name))) - count += 1 - - #Save text in txt - if images_text: - with open(f'{output_path}image_{file.split(".")[0]}.txt', 'w') as f: - for line in images_text: - clean_line = line.replace("\n\n","\n") - f.write(f'{clean_line}\n\n') - - #Save text from pdf - with open(f'{output_path}{file.split(".")[0]}.txt', 'w') as f: - for line in text_list: - f.write(f"{line}\n") - -def create_vector(output_path): - # ************ HEALTH DATA VECTOR ************ - if not os.path.exists('vectors/'): os.makedirs('vectors/') - llm = OpenAI(temperature=0) - - loader = DirectoryLoader(output_path) - documents = loader.load() - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) - documents = text_splitter.split_documents(documents) - embeddings = OpenAIEmbeddings() - - vector_folder = "./vectors/"+output_path[:-1].split("/")[1] - - vectorstore = Chroma.from_documents(documents, embeddings, persist_directory=vector_folder) - vectorstore.persist() - if __name__ == '__main__': socketio.run(app) diff --git a/backend/python-langchain/requirements.txt b/backend/python-langchain/requirements.txt index 4314a28..2e67fde 100644 --- a/backend/python-langchain/requirements.txt +++ b/backend/python-langchain/requirements.txt @@ -1,6 +1,7 @@ flask==2.3.2 flask-babel==3.1.0 flask-cors==4.0.0 +Werkzeug==2.3.8 Flask-SocketIO==5.3.4 Flask-Session==0.5.0 Flask-Login==0.6.2 @@ -20,7 +21,6 @@ pytesseract==0.3.10 pypdf==3.9.1 pypdf[image] requests==2.31.* -langchain[llm] langchain chromadb openai diff --git a/ui/src/App.js b/ui/src/App.js index 0fe18c8..9a03992 100644 --- a/ui/src/App.js +++ b/ui/src/App.js @@ -57,20 +57,17 @@ function App() { { logged? 
( - {filesUploaded && - - - } + - {!filesUploaded && + {filesUploaded &&

Upload any PDF, CSV, image or JSON and start chatting with your health data.

          }
-        {filesUploaded &&
+        {!filesUploaded &&
          }
diff --git a/ui/src/components/Chat.js b/ui/src/components/Chat.js
index 844de16..9bfdfaf 100644
--- a/ui/src/components/Chat.js
+++ b/ui/src/components/Chat.js
@@ -20,7 +20,6 @@ function Chat({userId,...props}) {
       const response = await fetch(sendMessageUrl, {
         method: "POST",
         mode: "cors",
-        credentials: 'include',
         headers: {
           "Content-Type": "application/json"
         },
@@ -29,42 +28,17 @@
       setLoadingResp(false);
-      const reader = response.body.getReader();
-      const decoder = new TextDecoder();
-      let result = '';
-      while (true) {
-        const { done, value } = await reader.read();
-        if (done) {
-          break;
-        }
-
-
-        if (value) {
-          const text = decoder.decode(value);
-          result += text;
-
-          setMessages((prevMessages) => {
-            const lastMessage = prevMessages[prevMessages.length - 1];
-
-            if (lastMessage && lastMessage.class === "system-msg") {
-              return [...prevMessages.slice(0, -1), { class: "system-msg", msg: result }];
-            } else {
-              return [...prevMessages, { class: "system-msg", msg: result }];
-            }
-          });
-
-        }
+      if (response.ok) {
+        const responseJSON = await response.json();
+        setMessages(messages => [...messages, { "class": "system-msg", "msg": responseJSON.response}]);
+        setHistory(history => [...history, { "output": responseJSON.response }]);
+      } else {
+        throw new Error("Request failed");
       }
-
-      setHistory(history => [...history, { "output": result }]);
     } catch (error) {
-      setLoadingResp(false);
       console.error(error); // Handle the request error
     }
-
-
-
   }
 }
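With the agent and streaming pipeline removed, /send-message in main.py now classifies the question, invokes one of the runnables, and returns a single JSON body of the form {"response": "..."}, which is what the updated Chat.js reads via responseJSON.response. One detail of the routing shown above: in `elif "cancerguidelines" or "general" in classification.lower():` the string literal "cancerguidelines" is always truthy and short-circuits the `or`, so the branch fires for any non-trials classification without actually comparing against the label (which happens to be harmless here only because both remaining branches call the general chain). A minimal sketch of the intended routing, under the assumption that the chains and retriever are those defined in main.py; the helper name and signature are ours, not part of this diff:

# Illustrative routing helper (name and signature not taken from the diff).
# trials_runnable and general_runnable are the chains built in main.py above;
# docs come from retriever.get_relevant_documents(message).
def route_message(message, classification, docs, trials_runnable, general_runnable):
    label = classification.strip().lower()

    if "clinicaltrials" in label:
        # Build the context block expected by trials_prompt.
        context = "\n\n".join(
            f'Clinical trial nctId: {doc.metadata["nctId"]}\nInformation: {doc.page_content}'
            for doc in docs
        )
        return trials_runnable.invoke({"question": message, "context": context})

    # "cancerGuidelines", "general", and any unexpected label all use the
    # general chain, so an explicit keyword test per label is only needed if
    # the fallback behaviour should ever differ.
    return general_runnable.invoke({"question": message})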