diff --git a/.gitignore b/.gitignore index ce68ca4e4..0d770b1ec 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,10 @@ Thumbs.db # Sphinx build _build/ + +# environment +.env + +# temp doc files +doc/readme.md +doc/xml diff --git a/.gitmodules b/.gitmodules index 455a55143..7dea81d36 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,18 @@ [submodule "SimpleOT"] - path = SimpleOT + path = deps/SimpleOT url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] - path = mpir - url = git://github.com/wbhart/mpir.git + path = deps/mpir + url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion [submodule "simde"] - path = simde + path = deps/simde url = https://github.com/simd-everywhere/simde +[submodule "deps/libOTe"] + path = deps/libOTe + url = https://github.com/mkskeller/softspoken-implementation +[submodule "deps/SimplestOT_C"] + path = deps/SimplestOT_C + url = https://github.com/mkskeller/SimplestOT_C diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 5ca1360ab..0fe11a0f1 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) : } cout << "Compiler: " << prev << endl; P = new PlainPlayer(N, 0); + Share::MAC_Check::setup(*P); if (argc > 4) threshold = atoi(argv[4]); cout << "Threshold for multi-threaded evaluation: " << threshold << endl; @@ -259,7 +260,6 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl; delete P; } delete[] eval_threads; @@ -281,6 +281,7 @@ FakeProgramParty::~FakeProgramParty() cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes() << " GB" << endl; #endif + Share::MAC_Check::teardown(); } void FakeProgramParty::_compute_prfs_outputs(Key* keys) diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h index 9fa2dc521..115d0bcaa 100644 --- a/BMR/RealGarbleWire.h +++ b/BMR/RealGarbleWire.h @@ -48,8 +48,6 @@ class RealGarbleWire : public PRFRegister static void inputbvec(GC::Processor>& processor, ProcessorBase& input_processor, const vector& args); - RealGarbleWire(const Register& reg) : PRFRegister(reg) {} - void garble(PRFOutputs& prf_output, const RealGarbleWire& left, const RealGarbleWire& right); diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 55adcbfb1..c9e31fc65 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -110,7 +110,7 @@ void RealGarbleWire::inputbvec( { GarbleInputter inputter; processor.inputbvec(inputter, input_processor, args, - inputter.party.P->my_num()); + *inputter.party.P); } template @@ -175,7 +175,7 @@ void GarbleInputter::exchange() assert(party.P != 0); assert(party.MC != 0); auto& protocol = party.shared_proc->protocol; - protocol.init_mul(party.shared_proc); + protocol.init_mul(); for (auto& tuple : tuples) protocol.prepare_mul(tuple.first->mask, T::constant(1, party.P->my_num(), party.mac_key) diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 22438deb4..421394146 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -28,7 +28,7 @@ RealProgramParty* RealProgramParty::singleton = 0; template RealProgramParty::RealProgramParty(int argc, const char** argv) : - garble_processor(garble_machine), dummy_proc({{}, 0}) + garble_processor(garble_machine), dummy_proc({}, 0) { assert(singleton == 0); singleton = this; @@ -64,7 +64,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : online_opts = {opt, argc, argv, 1000}; else online_opts = {opt, argc, argv}; - assert(not online_opts.interactive); online_opts.finalize(opt, argc, argv); this->load(online_opts.progname); @@ -97,8 +96,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : if (online_opts.live_prep) { mac_key.randomize(prng); - if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({*P, true}); prep = new typename T::LivePrep(0, usage); } else @@ -107,13 +104,13 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : prep = new Sub_Data_Files(N, prep_dir, usage); } + T::MAC_Check::setup(*P); MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); - this->processor.open_input_file(N.my_num(), 0); + this->processor.open_input_file(N.my_num(), 0, online_opts.cmd_private_input_file); + this->processor.setup_redirection(P->my_num(), 0, online_opts, this->processor.out); - T::bit_type::mac_key_type::init_field(); - GC::ShareThread share_thread(N, online_opts, *P, 0, usage); shared_proc = new SubProcessor(dummy_proc, *MC, *prep, *P); auto& inputter = shared_proc->input; @@ -157,7 +154,10 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : while (next != GC::DONE_BREAK); MC->Check(*P); - data_sent = P->comm_stats.total_data() + prep->data_sent(); + data_sent = P->total_comm().sent; + + if (online_opts.verbose) + P->total_comm().print(); this->machine.write_memory(this->N.my_num()); } @@ -175,7 +175,8 @@ void RealProgramParty::garble() garble_jobs.clear(); garble_inputter->reset_all(*P); auto& protocol = *garble_protocol; - protocol.init_mul(shared_proc); + protocol.init(*prep, shared_proc->MC); + protocol.init_mul(); next = this->first_phase(program, garble_processor, this->garble_machine); @@ -183,7 +184,8 @@ void RealProgramParty::garble() protocol.exchange(); typename T::Protocol second_protocol(*P); - second_protocol.init_mul(shared_proc); + second_protocol.init(*prep, shared_proc->MC); + second_protocol.init_mul(); for (auto& job : garble_jobs) job.middle_round(*this, second_protocol); @@ -215,6 +217,7 @@ RealProgramParty::~RealProgramParty() delete garble_inputter; delete garble_protocol; cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; + T::MAC_Check::teardown(); } template diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 77c7af577..8f75f13a7 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -568,6 +568,7 @@ void EvalRegister::inputb(GC::Processor >& processor, octetStream& my_os = oss[party.get_id() - 1]; vector accesses; InputArgList a(args); + processor.complexity += a.n_input_bits(); for (auto x : a) { accesses.push_back({x , processor}); diff --git a/BMR/Register.h b/BMR/Register.h index 886155d79..2085eb25a 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -23,6 +23,7 @@ using namespace std; #include "Tools/PointerVector.h" #include "Tools/Bundle.h" #include "Tools/SwitchableOutput.h" +#include "Processor/Instruction.h" //#define PAD_TO_8(n) (n+8-n%8) #define PAD_TO_8(n) (n) @@ -151,7 +152,7 @@ class Register { * for pipelining matters. */ - Register(int n_parties); + Register(); void init(int n_parties); void init(int rfd, int n_parties); @@ -234,6 +235,9 @@ class Phase template static void ands(T& processor, const vector& args) { processor.ands(args); } template + static void andrsvec(T& processor, const vector& args) + { processor.andrsvec(args); } + template static void xors(T& processor, const vector& args) { processor.xors(args); } template static void inputb(T& processor, const vector& args) { processor.input(args); } @@ -243,6 +247,9 @@ class Phase template static T get_input(int from, GC::Processor& processor, int n_bits) { return T::input(from, processor.get_input(n_bits), n_bits); } + template + static void reveal_inst(GC::Processor& processor, const vector& args) + { processor.reveal(args); } template static void convcbit(Integer& dest, const GC::Clear& source, T&) @@ -274,10 +281,6 @@ class ProgramRegister : public Phase, public Register static int threshold(int) { throw not_implemented(); } - static Register new_reg(); - static Register tmp_reg() { return new_reg(); } - static Register and_reg() { return new_reg(); } - template static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } @@ -286,6 +289,15 @@ class ProgramRegister : public Phase, public Register static void inputbvec(T& processor, ProcessorBase& input_processor, const vector& args); + template + static void convcbit2s(GC::Processor&, const BaseInstruction&) + { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } + + static void run_tapes(const vector&) { throw not_implemented(); } + // most BMR phases don't need actual input template static T get_input(GC::Processor& processor, const InputArgs& args) @@ -295,8 +307,6 @@ class ProgramRegister : public Phase, public Register void other_input(Input&, int) {} char get_output() { return 0; } - - ProgramRegister(const Register& reg) : Register(reg) {} }; class PRFRegister : public ProgramRegister @@ -308,8 +318,6 @@ class PRFRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - PRFRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const PRFRegister& left, const PRFRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); @@ -385,8 +393,6 @@ class EvalRegister : public ProgramRegister static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>& proc); - EvalRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const ProgramRegister& left, const ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); @@ -416,8 +422,6 @@ class GarbleRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - GarbleRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const Register& left, const Register& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char value = -1); @@ -441,8 +445,6 @@ class RandomRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - RandomRegister(const Register& reg) : ProgramRegister(reg) {} - void randomize(); void op(const Register& left, const Register& right, Function func); @@ -458,12 +460,6 @@ class RandomRegister : public ProgramRegister }; -inline Register::Register(int n_parties) : - garbled_entry(n_parties), external(NO_SIGNAL), - mask(NO_SIGNAL), keys(n_parties) -{ -} - inline void KeyVector::operator=(const KeyVector& other) { resize(other.size()); diff --git a/BMR/Register.hpp b/BMR/Register.hpp index bd214a858..617906945 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor, const vector& args) { NoOpInputter inputter; - int my_num = -1; - try - { - my_num = ProgramParty::s().P->my_num(); - } - catch (exception&) - { - } - processor.inputbvec(inputter, input_processor, args, my_num); + processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P); } template @@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor, { EvalInputter inputter; processor.inputbvec(inputter, input_processor, args, - ProgramParty::s().P->my_num()); + *ProgramParty::s().P); } template diff --git a/BMR/Register_inline.h b/BMR/Register_inline.h index 6a275da64..7694c464d 100644 --- a/BMR/Register_inline.h +++ b/BMR/Register_inline.h @@ -9,10 +9,10 @@ #include "CommonParty.h" #include "Party.h" - -inline Register ProgramRegister::new_reg() +inline Register::Register() : + garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL), + mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties()) { - return Register(CommonParty::s().get_n_parties()); } #endif /* BMR_REGISTER_INLINE_H_ */ diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 6bd1ba264..439bcfc73 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty() _received_gc_received = 0; n_received = 0; randomfd = open("/dev/urandom", O_RDONLY); + done_filling = false; +} + +BaseTrustedParty::~BaseTrustedParty() +{ + close(randomfd); } TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index 24e8120de..260e7a516 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -20,7 +20,7 @@ class BaseTrustedParty : virtual public CommonFakeParty { vector msg_input_masks; BaseTrustedParty(); - virtual ~BaseTrustedParty() {} + virtual ~BaseTrustedParty(); /* From NodeUpdatable class */ virtual void NodeReady(); @@ -104,7 +104,6 @@ class TrustedProgramParty : public BaseTrustedParty { void add_all_keys(const Register& reg, bool external); }; - inline void BaseTrustedParty::add_keys(const Register& reg) { for(int p = 0; p < get_n_parties(); p++) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e23aaa68..9a3a276d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,95 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.5 (Feb 16, 2023) + +- Easier-to-use machine learning interface +- Integrated compilation-execution facility +- Import/export sequential models and parameters from/to PyTorch +- Binary-format input files +- Less aggressive round optimization for faster compilation by default +- Multithreading with client interface +- Functionality to protect order of specific memory accesses +- Oblivious transfer works again on older (pre-2011) x86 CPUs +- clang is used by default + +## 0.3.4 (Nov 9, 2022) + +- Decision tree learning +- Optimized oblivious shuffle in Rep3 +- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC +- Optimized element-vector AND in SemiBin +- Optimized input protocol in Shamir-based protocols +- Square-root ORAM (@Quitlox) +- Improved ORAM in binary circuits +- UTF-8 outputs + +## 0.3.3 (Aug 25, 2022) + +- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate +- Fix security bug in MAC check when using multithreading +- Fix security bug to prevent selective failure attack by checking earlier +- Fix security bug in Mama: insufficient sacrifice. +- Inverse permutation (@Quitlox) +- Easier direct compilation (@eriktaubeneck) +- Generally allow element-vector operations +- Increase maximum register size to 2^54 +- Client example in Python +- Uniform base OTs across platforms +- Multithreaded base OT computation +- Faster random bit generation in two-player Semi(2k) + +## 0.3.2 (May 27, 2022) + +- Secure shuffling +- O(n log n) radix sorting +- Documented BGV encryption interface +- Optimized matrix multiplication in dealer protocol +- Fixed security bug in homomorphic encryption parameter generation +- Fixed security bug in Temi matrix multiplication + +## 0.3.1 (Apr 19, 2022) + +- Protocol in dealer model +- Command-line option for security parameter +- Fixed security bug in SPDZ2k (see Section 3.4 of [the updated paper](https://eprint.iacr.org/2018/482)) +- Ability to run high-level (Python) code from C++ +- More memory capacity due to 64-bit addressing +- Homomorphic encryption for more fields of characteristic two +- Docker container + +## 0.3.0 (Feb 17, 2022) + +- Semi-honest computation based on threshold semi-homomorphic encryption +- Batch normalization backward propagation +- AlexNet for CIFAR-10 +- Specific private output protocols +- Semi-honest additive secret sharing without communication +- Sending of personal values +- Allow overwriting of persistence files +- Protocol signature in persistence files + +## 0.2.9 (Jan 11, 2022) + +- Disassembler +- Run-time parameter for probabilistic truncation error +- Probabilistic truncation for some protocols computing modulo a prime +- Simplified C++ interface +- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757) +- More general scalar-vector multiplication +- Complete memory support for clear bits +- Extended clear bit functionality with Yao's garbled circuits +- Allow preprocessing information to be supplied via named pipes +- In-place operations for containers + +## 0.2.8 (Nov 4, 2021) + +- Tested on Apple laptop with ARM chip +- Restore trusted client interface +- Directly accessible softmax function +- Signature in preprocessing files to reduce confusing errors +- Improved error messages for connection issues +- Documentation of low-level share types and protocol pairs + ## 0.2.7 (Sep 17, 2021) - Optimized matrix multiplication in Hemi diff --git a/CONFIG b/CONFIG index 5f12d2c33..6d5f0f170 100644 --- a/CONFIG +++ b/CONFIG @@ -8,6 +8,9 @@ GDEBUG = -g # set this to your preferred local storage directory PREP_DIR = '-DPREP_DIR="Player-Data/"' +# directory to store SSL keys +SSL_DIR = '-DSSL_DIR="Player-Data/"' + # set for SHE preprocessing (SPDZ and Overdrive) USE_NTL = 0 @@ -28,25 +31,23 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx ARCH = -march=native MACHINE := $(shell uname -m) +ARM := $(shell uname -m | grep x86; echo $$?) OS := $(shell uname -s) ifeq ($(MACHINE), x86_64) -# set this to 0 to avoid using AVX for OT ifeq ($(OS), Linux) -CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?) -ifeq ($(CHECK_AVX), 0) AVX_OT = 1 else AVX_OT = 0 endif else -AVX_OT = 1 -endif -else +ARCH = AVX_OT = 0 endif +USE_KOS = 0 + # allow to set compiler in CONFIG.mine -CXX = g++ +CXX = clang++ # use CONFIG.mine to overwrite DIR settings -include CONFIG.mine @@ -66,8 +67,11 @@ endif # MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) +LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS += -lboost_system -lssl -lcrypto +CFLAGS += -I./local/include + ifeq ($(USE_NTL),1) CFLAGS += -DUSE_NTL LDLIBS := -lntl $(LDLIBS) @@ -83,12 +87,20 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) ifeq ($(OS), Darwin) +# for boost with OpenSSL 3 +CFLAGS += -Wno-error=deprecated-declarations ifeq ($(USE_NTL),1) -CFLAGS += -Wno-error=unused-parameter +CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy +endif endif + +ifeq ($(USE_KOS),1) +CFLAGS += -DUSE_KOS +else +CFLAGS += -std=c++17 endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index f1b5ad236..73a8af216 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -13,6 +13,7 @@ import Compiler.tools as tools import collections import itertools +import math class SecretBitsAF(base.RegisterArgFormat): reg_type = 'sb' @@ -50,6 +51,7 @@ class ClearBitsAF(base.RegisterArgFormat): INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, + ANDRSVEC = 0x24a, XORCBI = 0x210, BITDECC = 0x211, NOTCB = 0x212, @@ -64,6 +66,8 @@ class ClearBitsAF(base.RegisterArgFormat): MULCBI = 0x21c, SHRCBI = 0x21d, SHLCBI = 0x21e, + LDMCBI = 0x258, + STMCBI = 0x259, CONVCINTVEC = 0x21f, PRINTREGSIGNED = 0x220, PRINTREGB = 0x221, @@ -153,6 +157,52 @@ class andrs(BinaryVectorInstruction): def add_usage(self, req_node): req_node.increment(('bit', 'triple'), sum(self.args[::4])) + req_node.increment(('bit', 'mixed'), + sum(int(math.ceil(x / 64)) for x in self.args[::4])) + +class andrsvec(base.VarArgsInstruction, base.Mergeable, + base.DynFormatInstruction): + """ Constant-vector AND of secret bit registers (vectorized version). + + :param: total number of arguments to follow (int) + :param: number of arguments to follow for one operation / + operation vector size plus three (int) + :param: vector size (int) + :param: result vector (sbit) + :param: (repeat)... + :param: constant operand (sbits) + :param: vector operand + :param: (repeat)... + :param: (repeat from number of arguments to follow for one operation)... + + """ + code = opcodes['ANDRSVEC'] + + def __init__(self, *args, **kwargs): + super(andrsvec, self).__init__(*args, **kwargs) + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + for x in self.args[i + 2:i + n]: + assert x.n == size + + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + n_args = (n - 3) // 2 + assert n_args > 0 + for j in range(n_args): + yield 'sbw' + for j in range(n_args + 1): + yield 'sb' + yield 'int' + + def add_usage(self, req_node): + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + req_node.increment(('bit', 'triple'), size * (n - 3) // 2) + req_node.increment(('bit', 'mixed'), size) class ands(BinaryVectorInstruction): """ Bitwise AND of secret bit register vector. @@ -303,7 +353,7 @@ class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, :param: memory address (int) """ code = opcodes['LDMSB'] - arg_format = ['sbw','int'] + arg_format = ['sbw','long'] class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): """ Copy secret bit register to secret bit memory cell with compile-time @@ -313,7 +363,7 @@ class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): :param: memory address (int) """ code = opcodes['STMSB'] - arg_format = ['sb','int'] + arg_format = ['sb','long'] # def __init__(self, *args, **kwargs): # super(type(self), self).__init__(*args, **kwargs) # import inspect @@ -328,7 +378,7 @@ class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, :param: memory address (int) """ code = opcodes['LDMCB'] - arg_format = ['cbw','int'] + arg_format = ['cbw','long'] class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): """ Copy clear bit register to clear bit memory cell with compile-time @@ -338,9 +388,10 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): :param: memory address (int) """ code = opcodes['STMCB'] - arg_format = ['cb','int'] + arg_format = ['cb','long'] -class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit memory cell with run-time address to secret bit register. @@ -349,8 +400,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] + direct = staticmethod(ldmsb) -class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit register to secret bit memory cell with run-time address. @@ -359,6 +412,31 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMSBI'] arg_format = ['sb','ci'] + direct = staticmethod(stmsb) + +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): + """ Copy clear bit memory cell with run-time address to clear bit + register. + + :param: destination (cbit) + :param: memory address (regint) + """ + code = opcodes['LDMCBI'] + arg_format = ['cbw','ci'] + direct = staticmethod(ldmcb) + +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): + """ Copy clear bit register to clear bit memory cell with run-time + address. + + :param: source (cbit) + :param: memory address (regint) + """ + code = opcodes['STMCBI'] + arg_format = ['cb','ci'] + direct = staticmethod(stmcb) class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] @@ -475,7 +553,7 @@ class movsb(NonVectorInstruction): code = opcodes['MOVSB'] arg_format = ['sbw','sb'] -class trans(base.VarArgsInstruction): +class trans(base.VarArgsInstruction, base.DynFormatInstruction): """ Secret bit register vector transpose. The first destination vector will contain the least significant bits of all source vectors etc. @@ -489,10 +567,22 @@ class trans(base.VarArgsInstruction): code = opcodes['TRANS'] is_vec = lambda self: True def __init__(self, *args): - self.arg_format = ['int'] + ['sbw'] * args[0] + \ - ['sb'] * (len(args) - 1 - args[0]) super(trans, self).__init__(*args) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + n = next(args) + for i in range(n): + yield 'sbw' + next(args) + while True: + try: + yield 'sb' + next(args) + except StopIteration: + break + class bitb(NonVectorInstruction): """ Copy fresh secret random bit to secret bit register. @@ -538,7 +628,7 @@ def add_usage(self, req_node): req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1]) class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, - base.Mergeable): + base.Mergeable, base.DynFormatInstruction): """ Copy private input to secret bit registers bit by bit. The input is read as floating-point number, multiplied by a power of two, rounded to an integer, and then decomposed into bits. @@ -555,11 +645,19 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, code = opcodes['INPUTBVEC'] def __init__(self, *args, **kwargs): - self.arg_format = [] - for x in self.get_arg_tuples(args): - self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3) super(inputbvec, self).__init__(*args, **kwargs) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + yield 'p' + assert n > 3 + for j in range(n - 3): + yield 'sbw' + yield 'int' + @staticmethod def get_arg_tuples(args): i = 0 @@ -568,10 +666,6 @@ def get_arg_tuples(args): i += args[i] assert i == len(args) - def merge(self, other): - self.args += other.args - self.arg_format += other.arg_format - def add_usage(self, req_node): for x in self.get_arg_tuples(self.args): req_node.increment(('bit', 'input', x[2]), x[0] - 3) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index a1384475b..d353c148a 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -3,10 +3,13 @@ fixed-length types obtained by :py:obj:`get_type(n)` are the preferred way of using them, and in some cases required in connection with container types. + +Computation using these types will always be executed as a binary +circuit. See :ref:`protocol-pairs` for the exact protocols. """ from Compiler.types import MemValue, read_mem_value, regint, Array, cint -from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint +from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint, sintbit from Compiler.program import Tape, Program from Compiler.exceptions import * from Compiler import util, oram, floatingpoint, library @@ -14,10 +17,10 @@ import Compiler.GC.instructions as inst import operator import math +import itertools from functools import reduce class bits(Tape.Register, _structure, _bit): - """ Base class for binary registers. """ n = 40 unit = 64 PreOp = staticmethod(floatingpoint.PreOpN) @@ -41,7 +44,7 @@ class bitsn(cls): return cls.types[length] @classmethod def conv(cls, other): - if isinstance(other, cls): + if isinstance(other, cls) and cls.n == other.n: return other elif isinstance(other, MemValue): return cls.conv(other.read()) @@ -56,12 +59,12 @@ def compose(cls, items, bit_length=1): @classmethod def bit_compose(cls, bits): bits = list(bits) - if len(bits) == 1: + if len(bits) == 1 and isinstance(bits[0], cls): return bits[0] bits = list(bits) for i in range(len(bits)): if util.is_constant(bits[i]): - bits[i] = sbit(bits[i]) + bits[i] = cls.bit_type(bits[i]) res = cls.new(n=len(bits)) if len(bits) <= cls.unit: cls.bitcom(res, *(sbit.conv(bit) for bit in bits)) @@ -111,11 +114,16 @@ def load_mem(cls, address, mem_type=None, size=None): if mem_type == 'sd': return cls.load_dynamic_mem(address) else: - for i in range(res.size): - cls.load_inst[util.is_constant(address)](res[i], address + i) + cls.mem_op(cls.load_inst, res, address) return res def store_in_mem(self, address): - self.store_inst[isinstance(address, int)](self, address) + self.mem_op(self.store_inst, self, address) + @staticmethod + def mem_op(inst, reg, address): + direct = isinstance(address, int) + if not direct: + address = regint.conv(address) + inst[direct](reg, address) @classmethod def new(cls, value=None, n=None): if util.is_constant(value): @@ -147,19 +155,29 @@ def load_other(self, other): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): - assert(other.size == math.ceil(self.n / self.unit)) - for i, (x, y) in enumerate(zip(self, other)): + assert self.unit == 64 + n_units = int(math.ceil(self.n / self.unit)) + n_convs = min(other.size, n_units) + for i in range(n_convs): + x = self[i] + y = other[i] self.conv_regint(min(self.unit, self.n - i * self.unit), x, y) + for i in range(n_convs, n_units): + inst.ldbits(self[i], min(self.unit, self.n - i * self.unit), 0) elif (isinstance(self, type(other)) or isinstance(other, type(self))) \ and self.n == other.n: for i in range(math.ceil(self.n / self.unit)): self.mov(self[i], other[i]) + elif isinstance(other, sintbit) and isinstance(self, sbits): + assert len(other) == 1 + r = sint.get_dabit() + self.mov(self, r[1] ^ other.bit_xor(r[0]).reveal()) elif isinstance(other, sint) and isinstance(self, sbits): self.mov(self, sbitvec(other, self.n).elements()[0]) else: try: bits = other.bit_decompose() - bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits)) + bits = bits[:self.n] + [self.bit_type(0)] * (self.n - len(bits)) other = self.bit_compose(bits) assert(isinstance(other, type(self))) assert(other.n == self.n) @@ -184,6 +202,8 @@ def __and__(self, other): return 0 elif self.is_long_one(other): return self + elif isinstance(other, _vec): + return other & other.from_vec([self]) else: return self._and(other) @read_mem_value @@ -222,17 +242,30 @@ def if_else(self, x, y): This will output 1. """ return result_conv(x, y)(self & (x ^ y) ^ y) + def zero_if_not(self, condition): + if util.is_constant(condition): + return self * condition + else: + return self * cbit.conv(condition) + def expand(self, length): + if self.n in (length, None): + return self + elif self.n == 1: + return self.get_type(length).bit_compose([self] * length) + else: + raise CompilerError('cannot expand from %s to %s' % (self.n, length)) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ max_length = 64 reg_type = 'cb' is_clear = True - load_inst = (None, inst.ldmcb) - store_inst = (None, inst.stmcb) + load_inst = (inst.ldmcbi, inst.ldmcb) + store_inst = (inst.stmcbi, inst.stmcb) bitdec = inst.bitdecc conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) conv_cint_vec = inst.convcintvec + mov = staticmethod(lambda x, y: inst.addcbi(x, y, 0)) @classmethod def bit_compose(cls, bits): return sum(bit << i for i, bit in enumerate(bits)) @@ -241,14 +274,26 @@ def conv_regint_by_bit(cls, n, res, other): assert n == res.n assert n == other.size cls.conv_cint_vec(cint(other, size=other.size), res) + @classmethod + def conv(cls, other): + if isinstance(other, cbits) and cls.n != None and \ + cls.n // cls.unit == other.n // cls.unit: + if isinstance(other, cls): + return other + else: + res = cls() + for i in range(math.ceil(cls.n / cls.unit)): + cls.mov(res[i], other[i]) + return res + else: + return super(cbits, cls).conv(other) types = {} def load_int(self, value): - if self.n <= 64: - tmp = regint(value) - elif value == self.long_one(): - tmp = cint(1, size=self.n) - else: - raise CompilerError('loading long integers to cbits not supported') + n_limbs = math.ceil(self.n / self.unit) + tmp = regint(size=n_limbs) + for i in range(n_limbs): + tmp[i].load_int(value % 2 ** self.unit) + value >>= self.unit self.load_other(tmp) def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) @@ -270,8 +315,15 @@ def clear_op(self, other, c_inst, ci_inst, op): return op(self, cbits(other)) __add__ = lambda self, other: \ self.clear_op(other, inst.addcb, inst.addcbi, operator.add) - __sub__ = lambda self, other: \ - self.clear_op(-other, inst.addcb, inst.addcbi, operator.add) + def __sub__(self, other): + try: + return self + -other + except: + return type(self)(regint(self) - regint(other)) + def __rsub__(self, other): + return type(self)(other - regint(self)) + def __neg__(self): + return type(self)(-regint(self)) def _xor(self, other): if isinstance(other, (sbits, sbitvec)): return NotImplemented @@ -309,6 +361,8 @@ def __invert__(self): res = type(self)() inst.notcb(self.n, res, self) return res + def __eq__(self, other): + raise CompilerError('equality not implemented') def print_reg(self, desc=''): inst.print_regb(self, desc) def print_reg_plain(self): @@ -361,7 +415,6 @@ class sbits(bits): reg_type = 'sb' is_clear = False clear_type = cbits - default_type = cbits load_inst = (inst.ldmsbi, inst.ldmsb) store_inst = (inst.stmsbi, inst.stmsb) bitdec = inst.bitdecs @@ -383,16 +436,25 @@ def new(value=None, n=None): else: return sbits.get_type(n)(value) @staticmethod + def _new(value): + return value + @staticmethod def get_random_bit(): res = sbit() inst.bitb(res) return res + @staticmethod + def _check_input_player(player): + if not util.is_constant(player): + raise CompilerError('player must be known at compile time ' + 'for binary circuit inputs') @classmethod def get_input_from(cls, player, n_bits=None): """ Secret input from :py:obj:`player`. :param: player (int) """ + cls._check_input_player(player) if n_bits is None: n_bits = cls.n res = cls() @@ -461,6 +523,8 @@ def __mul__(self, other): if isinstance(other, int): return self.mul_int(other) try: + if (self.n, other.n) == (1, 1): + return self & other if min(self.n, other.n) != 1: raise NotImplementedError('high order multiplication') n = max(self.n, other.n) @@ -552,7 +616,15 @@ def trans(cls, rows): rows = list(rows) if len(rows) == 1 and rows[0].n <= rows[0].unit: return rows[0].bit_decompose() - n_columns = rows[0].n + for row in rows: + try: + n_columns = row.n + break + except: + pass + for i in range(len(rows)): + if util.is_zero(rows[i]): + rows[i] = cls.get_type(n_columns)(0) for row in rows: assert(row.n == n_columns) if n_columns == 1 and len(rows) <= cls.unit: @@ -576,7 +648,7 @@ def bit_adder(*args, **kwargs): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(_vec): +class sbitvec(_vec, _bit): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. @@ -584,7 +656,7 @@ class sbitvec(_vec): You can access the rows by member :py:obj:`v` and the columns by calling :py:obj:`elements`. - There are three ways to create an instance: + There are four ways to create an instance: 1. By transposition:: @@ -617,8 +689,14 @@ class sbitvec(_vec): This should output:: [1, 0, 1] + + 4. Private input:: + + x = sbitvec.get_type(32).get_input_from(player) + """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) + is_clear = False @classmethod def get_type(cls, n): """ Create type for fixed-length vector of registers of secret bits. @@ -632,17 +710,28 @@ def malloc(size, creator_tape=None): return sbit.malloc(size * n, creator_tape=creator_tape) @staticmethod def n_elements(): + return 1 + @staticmethod + def mem_size(): return n @classmethod - def get_input_from(cls, player): + def get_input_from(cls, player, size=1, f=0): """ Secret input from :py:obj:`player`. The input is decomposed into bits. :param: player (int) """ - res = cls.from_vec(sbit() for i in range(n)) - inst.inputbvec(n + 3, 0, player, *res.v) - return res + v = [0] * n + sbits._check_input_player(player) + instructions_base.check_vector_size(size) + for i in range(size): + vv = [sbit() for i in range(n)] + inst.inputbvec(n + 3, f, player, *vv) + for j in range(n): + tmp = vv[j] << i + v[j] = tmp ^ v[j] + sbits._check_input_player(player) + return cls.from_vec(v) get_raw_input_from = get_input_from @classmethod def from_vec(cls, vector): @@ -650,10 +739,12 @@ def from_vec(cls, vector): res.v = _complement_two_extend(list(vector), n)[:n] return res def __init__(self, other=None, size=None): - assert size in (None, 1) + instructions_base.check_vector_size(size) if other is not None: if util.is_constant(other): - self.v = [sbit((other >> i) & 1) for i in range(n)] + t = sbits.get_type(size or 1) + self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1)) + for i in range(n)] elif isinstance(other, _vec): self.v = self.bit_extend(other.v, n) elif isinstance(other, (list, tuple)): @@ -661,36 +752,41 @@ def __init__(self, other=None, size=None): else: self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n + assert size is None or size == self.v[0].n @classmethod - def load_mem(cls, address): + def load_mem(cls, address, size=None): + if size not in (None, 1): + assert isinstance(address, int) or len(address) == 1 + sb = sbits.get_type(size) + return cls.from_vec(sb.bit_compose( + sbit.load_mem(address + i + j * n) for j in range(size)) + for i in range(n)) if not isinstance(address, int) and len(address) == n: return cls.from_vec(sbit.load_mem(x) for x in address) else: return cls.from_vec(sbit.load_mem(address + i) for i in range(n)) def store_in_mem(self, address): + size = 1 for x in self.v: - assert util.is_constant(x) or x.n == 1 - v = [sbit.conv(x) for x in self.v] + if not util.is_constant(x): + size = max(size, x.n) + v = [sbits.get_type(size).conv(x) for x in self.v] if not isinstance(address, int) and len(address) == n: + assert max_n == 1 for x, y in zip(v, address): x.store_in_mem(y) else: + assert isinstance(address, int) or len(address) == 1 for i in range(n): - v[i].store_in_mem(address + i) + for j, x in enumerate(v[i].bit_decompose()): + x.store_in_mem(address + i + j * n) def reveal(self): - if len(self) > cbits.unit: - return self.elements()[0].reveal() - revealed = [cbit() for i in range(len(self))] - for i in range(len(self)): - try: - inst.reveal(1, revealed[i], self.v[i]) - except: - revealed[i] = cbit.conv(self.v[i]) - return cbits.get_type(len(self)).bit_compose(revealed) + return util.untuplify([x.reveal() for x in self.elements()]) @classmethod - def two_power(cls, nn): - return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) + def two_power(cls, nn, size=1): + return cls.from_vec( + [0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1)) def coerce(self, other): if util.is_constant(other): return self.from_vec(util.bit_decompose(other, n)) @@ -703,8 +799,12 @@ def bit_compose(cls, bits): bits += [0] * (n - len(bits)) assert len(bits) == n return cls.from_vec(bits) + def zero_if_not(self, condition): + return self.from_vec(x.zero_if_not(condition) for x in self.v) def __str__(self): return 'sbitvec(%d)' % n + sbitvecn.basic_type = sbitvecn + sbitvecn.reg_type = 'sb' return sbitvecn @classmethod def from_vec(cls, vector): @@ -721,6 +821,15 @@ def combine(cls, vectors): def from_matrix(cls, matrix): # any number of rows, limited number of columns return cls.combine(cls(row) for row in matrix) + @classmethod + def from_hex(cls, string): + """ Create from hexadecimal string (little-endian). """ + assert len(string) % 2 == 0 + v = [] + for i in range(0, len(string), 2): + v += [sbit(int(x)) + for x in reversed(bin(int(string[i:i + 2], 16))[2:].zfill(8))] + return cls.from_vec(v) def __init__(self, elements=None, length=None, input_length=None): if length: assert isinstance(elements, sint) @@ -767,19 +876,20 @@ def coerce(self, other): size = other.size return (other.get_vector(base, min(64, size - base)) \ for base in range(0, size, 64)) + if not isinstance(other, type(self)): + return type(self)(other) return other def __xor__(self, other): other = self.coerce(other) - return self.from_vec(x ^ y for x, y in zip(self.v, other)) + return self.from_vec(x ^ y for x, y in zip(*self.expand(other))) def __and__(self, other): - return self.from_vec(x & y for x, y in zip(self.v, other.v)) + return self.from_vec(x & y for x, y in zip(*self.expand(other))) + __rxor__ = __xor__ + __rand__ = __and__ + def __invert__(self): + return self.from_vec(~x for x in self.v) def if_else(self, x, y): - assert(len(self.v) == 1) - try: - return self.from_vec(util.if_else(self.v[0], a, b) \ - for a, b in zip(x, y)) - except: - return util.if_else(self.v[0], x, y) + return util.if_else(self.v[0], x, y) def __iter__(self): return iter(self.v) def __len__(self): @@ -792,6 +902,7 @@ def conv(cls, other): return cls.from_vec(other.v) else: return cls(other) + hard_conv = conv @property def size(self): if not self.v or util.is_constant(self.v[0]): @@ -804,7 +915,7 @@ def n_bits(self): def store_in_mem(self, address): for i, x in enumerate(self.elements()): x.store_in_mem(address + i) - def bit_decompose(self, n_bits=None, security=None): + def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None): return self.v[:n_bits] bit_compose = from_vec def reveal(self): @@ -821,6 +932,34 @@ def half_adder(self, other): def __mul__(self, other): if isinstance(other, int): return self.from_vec(x * other for x in self.v) + if isinstance(other, sbitvec): + if len(other.v) == 1: + other = other.v[0] + elif len(self.v) == 1: + self, other = other, self.v[0] + else: + raise CompilerError('no operand of lenght 1: %d/%d', + (len(self.v), len(other.v))) + if not isinstance(other, sbits): + return NotImplemented + ops = [] + for x in self.v: + if not util.is_zero(x): + assert x.n == other.n + ops.append(x) + if ops: + prods = [sbits.get_type(other.n)() for i in ops] + inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops) + res = [] + i = 0 + for x in self.v: + if util.is_zero(x): + res.append(0) + else: + res.append(prods[i]) + i += 1 + return sbitvec.from_vec(res) + __rmul__ = __mul__ def __add__(self, other): return self.from_vec(x + y for x, y in zip(self.v, other)) def bit_and(self, other): @@ -829,6 +968,60 @@ def bit_xor(self, other): return self ^ other def right_shift(self, m, k, security=None, signed=True): return self.from_vec(self.v[m:]) + def tree_reduce(self, function): + elements = self.elements() + while len(elements) > 1: + size = len(elements) + half = size // 2 + left = elements[:half] + right = elements[half:2*half] + odd = elements[2*half:] + sides = [self.from_vec(sbitvec(x).v) for x in (left, right)] + red = function(*sides) + elements = red.elements() + elements += odd + return self.from_vec(sbitvec(elements).v) + @classmethod + def comp_result(cls, x): + return cls.get_type(1).from_vec([x]) + def expand(self, other, expand=True): + m = 1 + for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []): + try: + m = max(m, x.n) + except: + pass + res = [] + if not util.is_constant(other): + other = self.coerce(other) + for y in self, other: + if isinstance(y, int): + res.append([x * sbits.get_type(m)().long_one() + for x in util.bit_decompose(y, len(self.v))]) + else: + res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) + return res + def demux(self): + if len(self) == 1: + return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]]) + a = sbitvec.from_vec(self.v[:len(self) // 2]).demux() + b = sbitvec.from_vec(self.v[len(self) // 2:]).demux() + prod = [a * bb for bb in b.v] + return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod))) + def reverse_bytes(self): + if len(self.v) % 8 != 0: + raise CompilerError('bit length not divisible by eight') + return self.from_vec(sum(reversed( + [self.v[i:i + 8] for i in range(0, len(self.v), 8)]), [])) + def reveal_print_hex(self): + """ Reveal and print in hexademical (one line per element). """ + for x in self.reverse_bytes().elements(): + x.reveal().print_reg() + def update(self, other): + other = self.conv(other) + assert len(self.v) == len(other.v) + for x, y in zip(self.v, other.v): + x.update(y) class bit(object): n = 1 @@ -879,10 +1072,11 @@ class cbit(bit, cbits): sbits.bit_type = sbit cbits.bit_type = cbit sbit.clear_type = cbit +sbits.default_type = sbits class bitsBlock(oram.Block): - value_type = sbits def __init__(self, value, start, lengths, entries_per_block): + self.value_type = type(value) oram.Block.__init__(self, value, lengths) length = sum(self.lengths) used_bits = entries_per_block * length @@ -927,7 +1121,10 @@ def _store(self, value, address): cbits.dynamic_array = Array def _complement_two_extend(bits, k): - return bits[:k] + [bits[-1]] * (k - len(bits)) + if len(bits) == 1: + return bits + [0] * (k - len(bits)) + else: + return bits[:k] + [bits[-1]] * (k - len(bits)) class _sbitintbase: def extend(self, n): @@ -986,6 +1183,9 @@ class sbitint(_bitint, _number, sbits, _sbitintbase): mul: 15 lt: 0 + This class is retained for compatibility, but development now + focuses on :py:class:`sbitintvec`. + """ n_bits = None bin_type = None @@ -1077,7 +1277,7 @@ def pow2(self, k): :param k: bit length of input """ return _sbitintbase.pow2(self, k) -class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): +class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ Vector of signed integers for parallel binary computation:: @@ -1112,19 +1312,34 @@ def elements(self): def __add__(self, other): if util.is_zero(other): return self - other = self.coerce(other) - assert(len(self.v) == len(other.v)) - v = sbitint.bit_adder(self.v, other.v) - return self.from_vec(v) + a, b = self.expand(other) + v = sbitint.bit_adder(a, b) + return self.get_type(len(v)).from_vec(v) __radd__ = __add__ + __sub__ = _bitint.__sub__ + def __rsub__(self, other): + a, b = self.expand(other) + return self.from_vec(b) - self.from_vec(a) def __mul__(self, other): if isinstance(other, sbits): return self.from_vec(other * x for x in self.v) + elif len(self.v) == 1: + return other * self.v[0] elif isinstance(other, sbitfixvec): return NotImplemented + my_bits, other_bits = self.expand(other, False) matrix = [] - for i, b in enumerate(util.bit_decompose(other)): - matrix.append([x & b for x in self.v[:len(self.v)-i]]) + m = float('inf') + for x in itertools.chain(my_bits, other_bits): + try: + m = min(m, x.n) + except: + pass + for i, b in enumerate(other_bits): + if m == 1: + matrix.append([x * b for x in my_bits[:len(self.v)-i]]) + else: + matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ @@ -1155,22 +1370,27 @@ class cbitfix(object): store_in_mem = lambda self, *args: self.v.store_in_mem(*args) @classmethod def _new(cls, value): + if isinstance(value, list): + return [cls._new(x) for x in value] res = cls() + if cls.k < value.unit: + bits = value.bit_decompose(cls.k) + sign = bits[-1] + value += (sign << (cls.k)) * -1 res.v = value return res def output(self): v = self.v - if self.k < v.unit: - bits = self.v.bit_decompose(self.k) - sign = bits[-1] - v += (sign << (self.k)) * -1 inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), cbits(0), cbits(0)) class sbitfix(_fix): - """ Secret signed integer in one binary register. + """ Secret signed fixed-point number in one binary register. Use :py:obj:`set_precision()` to change the precision. + This class is retained for compatibility, but development now + focuses on :py:class:`sbitfixvec`. + Example:: print_ln('add: %s', (sbitfix(0.5) + sbitfix(0.3)).reveal()) @@ -1209,6 +1429,7 @@ def get_input_from(cls, player): :param: player (int) """ + sbits._check_input_player(player) v = cls.int_type() inst.inputb(player, cls.k, cls.f, v) return cls._new(v) @@ -1231,7 +1452,7 @@ class cls(_fix): cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -class sbitfixvec(_fix): +class sbitfixvec(_fix, _vec): """ Vector of fixed-point numbers for parallel binary computation. Use :py:obj:`set_precision()` to change the precision. @@ -1260,23 +1481,27 @@ class sbitfixvec(_fix): int_type = sbitintvec.get_type(sbitfix.k) float_type = type(None) clear_type = cbitfix + @property + def bit_type(self): + return type(self.v[0]) @classmethod def set_precision(cls, f, k=None): super(sbitfixvec, cls).set_precision(f=f, k=k) cls.int_type = sbitintvec.get_type(cls.k) @classmethod - def get_input_from(cls, player): + def get_input_from(cls, player, size=1): """ Secret input from :py:obj:`player`. :param: player (int) """ - v = [sbit() for i in range(sbitfix.k)] - inst.inputbvec(len(v) + 3, sbitfix.f, player, *v) - return cls._new(cls.int_type.from_vec(v)) + return cls._new(cls.int_type.get_input_from(player, size=size, + f=sbitfix.f)) def __init__(self, value=None, *args, **kwargs): if isinstance(value, (list, tuple)): self.v = self.int_type.from_vec(sbitvec([x.v for x in value])) else: + if isinstance(value, sbitvec): + value = self.int_type(value) super(sbitfixvec, self).__init__(value, *args, **kwargs) def elements(self): return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()] @@ -1286,9 +1511,12 @@ def mul(self, other): else: return super(sbitfixvec, self).mul(other) def __xor__(self, other): + if util.is_zero(other): + return self return self._new(self.v ^ other.v) def __and__(self, other): return self._new(self.v & other.v) + __rxor__ = __xor__ @staticmethod def multipliable(other, k, f, size): class cls(_fix): diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 9a22da461..6a0d6b1d5 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -2,30 +2,3 @@ from .GC import types as GC_types import inspect from .config import * -from .compilerLib import run - - -# add all instructions to the program VARS dictionary -compilerLib.VARS = {} -instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] - -for mod in (types, GC_types): - instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ - if t[1].__module__ == mod.__name__] - -instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ - if t[1].__module__ == library.__name__] - -for op in instr_classes: - compilerLib.VARS[op.__name__] = op - -# add open and input separately due to name conflict -compilerLib.VARS['open'] = instructions.asm_open -compilerLib.VARS['vopen'] = instructions.vasm_open -compilerLib.VARS['gopen'] = instructions.gasm_open -compilerLib.VARS['vgopen'] = instructions.vgasm_open -compilerLib.VARS['input'] = instructions.asm_input -compilerLib.VARS['ginput'] = instructions.gasm_input - -compilerLib.VARS['comparison'] = comparison -compilerLib.VARS['floatingpoint'] = floatingpoint diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 6a1472d7b..980a189a4 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -15,10 +15,12 @@ class BlockAllocator: """ Manages freed memory blocks. """ def __init__(self): - self.by_logsize = [defaultdict(set) for i in range(32)] + self.by_logsize = [defaultdict(set) for i in range(64)] self.by_address = {} def by_size(self, size): + if size >= 2 ** 64: + raise CompilerError('size exceeds addressing capability') return self.by_logsize[int(math.log(size, 2))][size] def push(self, address, size): @@ -99,6 +101,7 @@ def dealloc_reg(self, reg, inst, free): self.dealloc |= reg.vector else: self.dealloc.add(reg) + reg.duplicates.remove(reg) base = reg.vectorbase seen = set_by_id() @@ -169,7 +172,7 @@ def finalize(self, options): for reg in self.alloc: for x in reg.get_all(): if x not in self.dealloc and reg not in self.dealloc \ - and len(x.duplicates) == 1: + and len(x.duplicates) == 0: print('Warning: read before write at register', x) print('\tregister trace: %s' % format_trace(x.caller, '\t\t')) @@ -259,6 +262,7 @@ def longest_paths_merge(self): instructions = self.instructions merge_nodes = self.open_nodes depths = self.depths + self.req_num = defaultdict(lambda: 0) if not merge_nodes: return 0 @@ -279,6 +283,7 @@ def longest_paths_merge(self): print('Merging %d %s in round %d/%d' % \ (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) + self.req_num[t.__name__, 'round'] += 1 preorder = None @@ -310,7 +315,6 @@ def dependency_graph(self, merge_classes): last_def = defaultdict_by_id(lambda: -1) last_mem_write = [] last_mem_read = [] - warned_about_mem = [] last_mem_write_of = defaultdict(list) last_mem_read_of = defaultdict(list) last_print_str = None @@ -359,20 +363,22 @@ def mem_access(n, instr, last_access_this_kind, last_access_other_kind): addr_i = addr + i handle_mem_access(addr_i, reg_type, last_access_this_kind, last_access_other_kind) - if block.warn_about_mem and not warned_about_mem and \ - (instr.get_size() > 100): + if block.warn_about_mem and \ + not block.parent.warned_about_mem and \ + (instr.get_size() > 100) and not instr._protect: print('WARNING: Order of memory instructions ' \ 'not preserved due to long vector, errors possible') - warned_about_mem.append(True) + block.parent.warned_about_mem = True else: handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind) - if block.warn_about_mem and not warned_about_mem and \ - not isinstance(instr, DirectMemoryInstruction): + if block.warn_about_mem and \ + not block.parent.warned_about_mem and \ + not isinstance(instr, DirectMemoryInstruction) and \ + not instr._protect: print('WARNING: Order of memory instructions ' \ 'not preserved, errors possible') - # hack - warned_about_mem.append(True) + block.parent.warned_about_mem = True def strict_mem_access(n, last_this_kind, last_other_kind): if last_other_kind and last_this_kind and \ @@ -401,6 +407,20 @@ def keep_merged_order(instr, n, t): add_edge(last_input[t][1], n) last_input[t][0] = n + def keep_text_order(inst, n): + if inst.get_players() is None: + # switch + for x in list(last_input.keys()): + if isinstance(x, int): + add_edge(last_input[x][0], n) + del last_input[x] + keep_merged_order(instr, n, None) + elif last_input[None][0] is not None: + keep_merged_order(instr, n, None) + else: + for player in inst.get_players(): + keep_merged_order(instr, n, player) + for n,instr in enumerate(block.instructions): outputs,inputs = instr.get_def(), instr.get_used() @@ -425,7 +445,7 @@ def keep_merged_order(instr, n, t): # will be merged if isinstance(instr, TextInputInstruction): - keep_merged_order(instr, n, TextInputInstruction) + keep_text_order(instr, n) elif isinstance(instr, RawInputInstruction): keep_merged_order(instr, n, RawInputInstruction) @@ -454,14 +474,14 @@ def keep_merged_order(instr, n, t): depths[n] = depth if isinstance(instr, ReadMemoryInstruction): - if options.preserve_mem_order: + if options.preserve_mem_order or instr._protect: strict_mem_access(n, last_mem_read, last_mem_write) - else: + elif not options.preserve_mem_order: mem_access(n, instr, last_mem_read_of, last_mem_write_of) elif isinstance(instr, WriteMemoryInstruction): - if options.preserve_mem_order: + if options.preserve_mem_order or instr._protect: strict_mem_access(n, last_mem_write, last_mem_read) - else: + elif not options.preserve_mem_order: mem_access(n, instr, last_mem_write_of, last_mem_read_of) elif isinstance(instr, matmulsm): if options.preserve_mem_order: @@ -476,11 +496,7 @@ def keep_merged_order(instr, n, t): add_edge(last_print_str, n) last_print_str = n elif isinstance(instr, PublicFileIOInstruction): - keep_order(instr, n, instr.__class__) - elif isinstance(instr, startprivateoutput_class): - keep_order(instr, n, startprivateoutput_class, 2) - elif isinstance(instr, stopprivateoutput_class): - keep_order(instr, n, stopprivateoutput_class, 2) + keep_order(instr, n, PublicFileIOInstruction) elif isinstance(instr, prep_class): keep_order(instr, n, instr.args[0]) elif isinstance(instr, StackInstruction): @@ -518,7 +534,9 @@ def eliminate_dead_code(self): can_eliminate_defs = True for reg in inst.get_def(): for dup in reg.duplicates: - if not dup.can_eliminate: + if not (dup.can_eliminate and reduce( + operator.and_, + (x.can_eliminate for x in dup.vector), True)): can_eliminate_defs = False break # remove if instruction has result that isn't used @@ -569,7 +587,7 @@ class RegintOptimizer: def __init__(self): self.cache = util.dict_by_id() - def run(self, instructions): + def run(self, instructions, program): for i, inst in enumerate(instructions): if isinstance(inst, ldint_class): self.cache[inst.args[0]] = inst.args[1] @@ -581,16 +599,10 @@ def run(self, instructions): self.cache[inst.args[0]] = res instructions[i] = ldint(inst.args[0], res, add_to_prog=False) - elif isinstance(inst, addint_class): - if inst.args[1] in self.cache and \ - self.cache[inst.args[1]] == 0: - instructions[i] = inst.args[0].link(inst.args[2]) - elif inst.args[2] in self.cache and \ - self.cache[inst.args[2]] == 0: - instructions[i] = inst.args[0].link(inst.args[1]) elif isinstance(inst, IndirectMemoryInstruction): if inst.args[1] in self.cache: instructions[i] = inst.get_direct(self.cache[inst.args[1]]) + instructions[i]._protect = inst._protect elif type(inst) == convint_class: if inst.args[1] in self.cache: res = self.cache[inst.args[1]] @@ -604,7 +616,13 @@ def run(self, instructions): if op == 0: instructions[i] = ldsi(inst.args[0], 0, add_to_prog=False) - elif op == 1: + elif isinstance(inst, (crash, cond_print_str, cond_print_plain)): + if inst.args[0] in self.cache: + cond = self.cache[inst.args[0]] + if not cond: instructions[i] = None - inst.args[0].link(inst.args[1]) + pre = len(instructions) instructions[:] = list(filter(lambda x: x is not None, instructions)) + post = len(instructions) + if pre != post and program.options.verbose: + print('regint optimizer removed %d instructions' % (pre - post)) diff --git a/Compiler/circuit.py b/Compiler/circuit.py index 9c4187f75..395c66146 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -63,7 +63,7 @@ def run(self, *inputs): i = 0 for l in self.n_output_wires: v = [] - for i in range(l): + for j in range(l): v.append(flat_res[i]) i += 1 res.append(sbitvec.from_vec(v)) @@ -127,18 +127,24 @@ def sha3_256(x): from circuit import sha3_256 a = sbitvec.from_vec([]) - b = sbitvec(sint(0xcc), 8, 8) - for x in a, b: - sha3_256(x).elements()[0].reveal().print_reg() + b = sbitvec.from_hex('cc') + c = sbitvec.from_hex('41fb') + d = sbitvec.from_hex('1f877c') + e = sbitvec.from_vec([sbit(0)] * 8) + for x in a, b, c, d, e: + sha3_256(x).reveal_print_hex() + + This should output the `test vectors + `_ + of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the + 0 byte:: + + Reg[0] = 0xa7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a # + Reg[0] = 0x677035391cd3701293d385f037ba32796252bb7ce180b00b582dd9b20aaad7f0 # + Reg[0] = 0x39f31b6e653dfcd9caed2602fd87f61b6254f581312fb6eeec4d7148fa2e72aa # + Reg[0] = 0xbc22345e4bd3f792a341cf18ac0789f1c9c966712a501b19d1b6632ccd408ec5 # + Reg[0] = 0x5d53469f20fef4f8eab52b88044ede69c77a6a68a60728609fc4a65ff531e7d0 # - This should output the first two test vectors of SHA3-256 in - byte-reversed order:: - - 0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7 - 0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067 - - Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only - implemented for computation modulo a power of two. """ global Keccak_f @@ -236,10 +242,10 @@ def circuit(cls, name): return cls._circuits[name] def __init__(self, value): - if isinstance(value, sbitvec): + if isinstance(value, (sbitint, sbitintvec)): + self.value = self.circuit('i2f')(sbitvec.conv(value)) + elif isinstance(value, sbitvec): self.value = value - elif isinstance(value, (sbitint, sbitintvec)): - self.value = self.circuit('i2f')(sbitvec(value)) elif util.is_constant_float(value): self.value = sbitvec(sbits.get_type(64)( struct.unpack('Q', struct.pack('d', value))[0])) diff --git a/Compiler/circuit_oram.py b/Compiler/circuit_oram.py index f5ddebfd6..a2cada540 100644 --- a/Compiler/circuit_oram.py +++ b/Compiler/circuit_oram.py @@ -1,5 +1,6 @@ -from Compiler.path_oram import * +from Compiler.oram import * +from Compiler.path_oram import PathORAM, XOR from Compiler.util import bit_compose def first_diff(a_bits, b_bits): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index f4cf89ad6..1a139ef6d 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -77,30 +77,28 @@ def LTZ(s, a, k, kappa): k: bit length of a """ + movs(s, program.non_linear.ltz(a, k, kappa)) + +def LtzRing(a, k): from .types import sint, _bitint from .GC.types import sbitvec if program.use_split(): summands = a.split_to_two_summands(k) carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] - movs(s, sint.conv(msb)) - return - elif program.options.ring: + return sint.conv(msb) + else: from . import floatingpoint require_ring_size(k, 'comparison') m = k - 1 shift = int(program.options.ring) - k r_prime, r_bin = MaskingBitsInRing(k) tmp = a - r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) - movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))) - return - t = sint() - Trunc(t, a, k, k - 1, kappa, True) - subsfi(s, t, 0) + return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)) def LessThanZero(a, k, kappa): from . import types @@ -191,7 +189,7 @@ def TruncLeakyInRing(a, k, m, signed): r = sint.bit_compose(r_bits) if signed: a += (1 << (k - 1)) - shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() + shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False) masked = shifted >> n_shift u = sint() BitLTL(u, masked, r_bits[:n_bits], 0) @@ -232,7 +230,7 @@ def Mod2mRing(a_prime, a, k, m, signed): shift = int(program.options.ring) - m r_prime, r_bin = MaskingBitsInRing(m, True) tmp = a + r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift u = sint() BitLTL(u, c_prime, r_bin[:m], 0) res = (u << m) + c_prime - r_prime @@ -262,7 +260,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) modc(c_prime, c, c2m) if const_rounds: BitLTC1(u, c_prime, r, kappa) @@ -293,7 +291,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): """ program.curr_tape.require_bit_length(k + kappa) from .types import sint - if program.use_edabit() and m > 1 and not const_rounds: + if program.use_edabit() and not const_rounds: movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0]) tmp, b[:] = sint.get_edabit(m, True) movs(r_prime, tmp) @@ -511,7 +509,7 @@ def PreMulC_with_inverses_and_vectors(p, a): movs(w[0], r[0]) movs(a_vec[0], a[0]) vmuls(k, t[0], w, a_vec) - vasm_open(k, m, t[0]) + vasm_open(k, True, m, t[0]) PreMulC_end(p, a, c, m, z) def PreMulC_with_inverses(p, a): @@ -539,7 +537,7 @@ def PreMulC_with_inverses(p, a): w[1][0] = r[0][0] for i in range(k): muls(t[0][i], w[1][i], a[i]) - asm_open(m[i], t[0][i]) + asm_open(True, m[i], t[0][i]) PreMulC_end(p, a, c, m, z) def PreMulC_without_inverses(p, a): @@ -564,7 +562,7 @@ def PreMulC_without_inverses(p, a): #adds(tt[0][i], t[0][i], a[i]) #subs(tt[1][i], tt[0][i], a[i]) #startopen(tt[1][i]) - asm_open(u[i], t[0][i]) + asm_open(True, u[i], t[0][i]) for i in range(k-1): muls(v[i], r[i+1], s[i]) w[0] = r[0] @@ -580,7 +578,7 @@ def PreMulC_without_inverses(p, a): mulm(z[i], s[i], u_inv[i]) for i in range(k): muls(t[1][i], w[i], a[i]) - asm_open(m[i], t[1][i]) + asm_open(True, m[i], t[1][i]) PreMulC_end(p, a, c, m, z) def PreMulC_end(p, a, c, m, z): @@ -638,6 +636,7 @@ def Mod2(a_0, a, k, kappa, signed): t = [program.curr_block.new_reg('s') for i in range(6)] c2k1 = program.curr_block.new_reg('c') PRandM(r_dprime, r_prime, [r_0], k, 1, kappa) + r_0 = r_prime mulsi(t[0], r_dprime, 2) if signed: ld2i(c2k1, k - 1) @@ -646,7 +645,7 @@ def Mod2(a_0, a, k, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 64e76434c..bb80dc344 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,94 +1,567 @@ -from Compiler.program import Program +import inspect +import os +import re +import sys +import tempfile +import subprocess +from optparse import OptionParser + +from Compiler.exceptions import CompilerError + from .GC import types as GC_types +from .program import Program, defaults -import sys -import re, tempfile, os - - -def run(args, options): - """ Compile a file and output a Program object. - - If options.merge_opens is set to True, will attempt to merge any - parallelisable open instructions. """ - - prog = Program(args, options) - VARS['program'] = prog - if options.binary: - VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) - VARS['sfix'] = GC_types.sbitfixvec - for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \ - 'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \ - 'squant': - del VARS[i] - - print('Compiling file', prog.infile) - f = open(prog.infile, 'rb') - - changed = False - if options.flow_optimization: - output = [] - if_stack = [] - for line in open(prog.infile): - if if_stack and not re.match(if_stack[-1][0], line): - if_stack.pop() - m = re.match( - '(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):', - line) - if m: - output.append('%s@for_range_opt(%s)\n' % (m.group(1), - m.group(3))) - output.append('%sdef _(%s):\n' % (m.group(1), m.group(2))) - changed = True - continue - m = re.match('(\s*)if(\W.*):', line) - if m: - if_stack.append((m.group(1), len(output))) - output.append('%s@if_(%s)\n' % (m.group(1), m.group(2))) - output.append('%sdef _():\n' % (m.group(1))) - changed = True - continue - m = re.match('(\s*)elif\s+', line) - if m: - raise CompilerError('elif not supported') - if if_stack: - m = re.match('%selse:' % if_stack[-1][0], line) - if m: - start = if_stack[-1][1] - ws = if_stack[-1][0] - output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws, - output[start]) - output.append('%s@else_\n' % ws) - output.append('%sdef _():\n' % ws) - continue - output.append(line) - if changed: - infile = tempfile.NamedTemporaryFile('w+', delete=False) - for line in output: - infile.write(line) - infile.seek(0) + +class Compiler: + def __init__(self, custom_args=None, usage=None, execute=False): + if usage: + self.usage = usage else: - infile = open(prog.infile) - else: - infile = open(prog.infile) + self.usage = "usage: %prog [options] filename [args]" + self.execute = execute + self.custom_args = custom_args + self.build_option_parser() + self.VARS = {} + + def build_option_parser(self): + parser = OptionParser(usage=self.usage) + parser.add_option( + "-n", + "--nomerge", + action="store_false", + dest="merge_opens", + default=defaults.merge_opens, + help="don't attempt to merge open instructions", + ) + parser.add_option("-o", "--output", dest="outfile", help="specify output file") + parser.add_option( + "-a", + "--asm-output", + dest="asmoutfile", + help="asm output file for debugging", + ) + parser.add_option( + "-g", + "--galoissize", + dest="galois", + default=defaults.galois, + help="bit length of Galois field", + ) + parser.add_option( + "-d", + "--debug", + action="store_true", + dest="debug", + help="keep track of trace for debugging", + ) + parser.add_option( + "-c", + "--comparison", + dest="comparison", + default="log", + help="comparison variant: log|plain|inv|sinv", + ) + parser.add_option( + "-M", + "--preserve-mem-order", + action="store_true", + dest="preserve_mem_order", + default=defaults.preserve_mem_order, + help="preserve order of memory instructions; possible efficiency loss", + ) + parser.add_option( + "-O", + "--optimize-hard", + action="store_true", + dest="optimize_hard", + help="lower number of rounds at higher compilation cost " + "(disables -C and increases the budget to 100000)", + ) + parser.add_option( + "-u", + "--noreallocate", + action="store_true", + dest="noreallocate", + default=defaults.noreallocate, + help="don't reallocate", + ) + parser.add_option( + "-m", + "--max-parallel-open", + dest="max_parallel_open", + default=defaults.max_parallel_open, + help="restrict number of parallel opens", + ) + parser.add_option( + "-D", + "--dead-code-elimination", + action="store_true", + dest="dead_code_elimination", + default=defaults.dead_code_elimination, + help="eliminate instructions with unused result", + ) + parser.add_option( + "-p", + "--profile", + action="store_true", + dest="profile", + help="profile compilation", + ) + parser.add_option( + "-s", + "--stop", + action="store_true", + dest="stop", + help="stop on register errors", + ) + parser.add_option( + "-R", + "--ring", + dest="ring", + default=defaults.ring, + help="bit length of ring (default: 0 for field)", + ) + parser.add_option( + "-B", + "--binary", + dest="binary", + default=defaults.binary, + help="bit length of sint in binary circuit (default: 0 for arithmetic)", + ) + parser.add_option( + "-G", + "--garbled-circuit", + dest="garbled", + action="store_true", + help="compile for binary circuits only (default: false)", + ) + parser.add_option( + "-F", + "--field", + dest="field", + default=defaults.field, + help="bit length of sint modulo prime (default: 64)", + ) + parser.add_option( + "-P", + "--prime", + dest="prime", + default=defaults.prime, + help="prime modulus (default: not specified)", + ) + parser.add_option( + "-I", + "--insecure", + action="store_true", + dest="insecure", + help="activate insecure functionality for benchmarking", + ) + parser.add_option( + "-b", + "--budget", + dest="budget", + help="set budget for optimized loop unrolling (default: %d)" % \ + defaults.budget, + ) + parser.add_option( + "-X", + "--mixed", + action="store_true", + dest="mixed", + help="mixing arithmetic and binary computation", + ) + parser.add_option( + "-Y", + "--edabit", + action="store_true", + dest="edabit", + help="mixing arithmetic and binary computation using edaBits", + ) + parser.add_option( + "-Z", + "--split", + default=defaults.split, + dest="split", + help="mixing arithmetic and binary computation " + "using direct conversion if supported " + "(number of parties as argument)", + ) + parser.add_option( + "--invperm", + action="store_true", + dest="invperm", + help="speedup inverse permutation (only use in two-party, " + "semi-honest environment)" + ) + parser.add_option( + "-C", + "--CISC", + action="store_true", + dest="cisc", + help="faster CISC compilation mode " + "(used by default unless -O is given)", + ) + parser.add_option( + "-K", + "--keep-cisc", + dest="keep_cisc", + help="don't translate CISC instructions", + ) + parser.add_option( + "-l", + "--flow-optimization", + action="store_true", + dest="flow_optimization", + help="optimize control flow", + ) + parser.add_option( + "-v", + "--verbose", + action="store_true", + dest="verbose", + help="more verbose output", + ) + if self.execute: + parser.add_option( + "-E", + "--execute", + dest="execute", + help="protocol to execute with", + ) + parser.add_option( + "-H", + "--hostfile", + dest="hostfile", + help="hosts to execute with", + ) + self.parser = parser + + def parse_args(self): + self.options, self.args = self.parser.parse_args(self.custom_args) + if self.execute: + if not self.options.execute: + raise CompilerError("must give name of protocol with '-E'") + protocol = self.options.execute + if protocol.find("ring") >= 0 or protocol.find("2k") >= 0 or \ + protocol.find("brain") >= 0 or protocol == "emulate": + if not (self.options.ring or self.options.binary): + self.options.ring = "64" + if self.options.field: + raise CompilerError( + "field option not compatible with %s" % protocol) + else: + if protocol.find("bin") >= 0 or protocol.find("ccd") >= 0 or \ + protocol.find("bmr") >= 0 or \ + protocol in ("replicated", "tinier", "tiny", "yao"): + if not self.options.binary: + self.options.binary = "32" + if self.options.ring or self.options.field: + raise CompilerError( + "ring/field options not compatible with %s" % + protocol) + if self.options.ring: + raise CompilerError( + "ring option not compatible with %s" % protocol) + if protocol == "emulate": + self.options.keep_cisc = '' + + def build_program(self, name=None): + self.prog = Program(self.args, self.options, name=name) + if self.execute: + if self.options.execute in \ + ("emulate", "ring", "rep-field", "semi2k"): + self.prog.use_trunc_pr = True + if self.options.execute in ("ring",): + self.prog.use_split(3) + if self.options.execute in ("semi2k",): + self.prog.use_split(2) + if self.options.execute in ("rep4-ring",): + self.prog.use_split(4) + + def build_vars(self): + from . import comparison, floatingpoint, instructions, library, types + + # add all instructions to the program VARS dictionary + instr_classes = [ + t[1] for t in inspect.getmembers(instructions, inspect.isclass) + ] + + for mod in (types, GC_types): + instr_classes += [ + t[1] + for t in inspect.getmembers(mod, inspect.isclass) + if t[1].__module__ == mod.__name__ + ] + + instr_classes += [ + t[1] + for t in inspect.getmembers(library, inspect.isfunction) + if t[1].__module__ == library.__name__ + ] + + for op in instr_classes: + self.VARS[op.__name__] = op + + # backward compatibility for deprecated classes + self.VARS["sbitint"] = GC_types.sbitintvec + self.VARS["sbitfix"] = GC_types.sbitfixvec + + # add open and input separately due to name conflict + self.VARS["vopen"] = instructions.vasm_open + self.VARS["gopen"] = instructions.gasm_open + self.VARS["vgopen"] = instructions.vgasm_open + self.VARS["ginput"] = instructions.gasm_input + + self.VARS["comparison"] = comparison + self.VARS["floatingpoint"] = floatingpoint + + self.VARS["program"] = self.prog + if self.options.binary: + self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary)) + self.VARS["sfix"] = GC_types.sbitfixvec + for i in [ + "cint", + "cfix", + "cgf2n", + "sintbit", + "sgf2n", + "sgf2nint", + "sgf2nuint", + "sgf2nuint32", + "sgf2nfloat", + "cfloat", + "squant", + ]: + del self.VARS[i] + + def prep_compile(self, name=None, build=True): + self.parse_args() + if len(self.args) < 1 and name is None: + self.parser.print_help() + exit(1) + if build: + self.build(name=name) + + def build(self, name=None): + self.build_program(name=name) + self.build_vars() - # make compiler modules directly accessible - sys.path.insert(0, 'Compiler') - # create the tapes - exec(compile(infile.read(), infile.name, 'exec'), VARS) + def compile_file(self): + """Compile a file and output a Program object. - if changed and not options.debug: - os.unlink(infile.name) + If options.merge_opens is set to True, will attempt to merge any + parallelisable open instructions.""" + print("Compiling file", self.prog.infile) - prog.finalize() + with open(self.prog.infile, "r") as f: + changed = False + if self.options.flow_optimization: + output = [] + if_stack = [] + for line in f: + if if_stack and not re.match(if_stack[-1][0], line): + if_stack.pop() + m = re.match( + r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_.]+)\):", + line, + ) + if m: + output.append( + "%s@for_range_opt(%s)\n" % (m.group(1), m.group(3)) + ) + output.append("%sdef _(%s):\n" % (m.group(1), m.group(2))) + changed = True + continue + m = re.match(r"(\s*)if(\W.*):", line) + if m: + if_stack.append((m.group(1), len(output))) + output.append("%s@if_(%s)\n" % (m.group(1), m.group(2))) + output.append("%sdef _():\n" % (m.group(1))) + changed = True + continue + m = re.match(r"(\s*)elif\s+", line) + if m: + raise CompilerError("elif not supported") + if if_stack: + m = re.match("%selse:" % if_stack[-1][0], line) + if m: + start = if_stack[-1][1] + ws = if_stack[-1][0] + output[start] = re.sub( + r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start] + ) + output.append("%s@else_\n" % ws) + output.append("%sdef _():\n" % ws) + continue + output.append(line) + if changed: + infile = tempfile.NamedTemporaryFile("w+", delete=False) + for line in output: + infile.write(line) + infile.seek(0) + else: + infile = open(self.prog.infile) + else: + infile = open(self.prog.infile) - if prog.req_num: - print('Program requires:') - for x in prog.req_num.pretty(): - print(x) + # make compiler modules directly accessible + sys.path.insert(0, "Compiler") + # create the tapes + exec(compile(infile.read(), infile.name, "exec"), self.VARS) - if prog.verbose: - print('Program requires:', repr(prog.req_num)) - print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) - print('Memory size:', dict(prog.allocated_mem)) + if changed and not self.options.debug: + os.unlink(infile.name) - return prog + return self.finalize_compile() + + def register_function(self, name=None): + """ + To register a function to be compiled, use this as a decorator. + Example: + + @compiler.register_function('test-mpc') + def test_mpc(compiler): + ... + """ + + def inner(func): + self.compile_name = name or func.__name__ + self.compile_function = func + return func + + return inner + + def compile_func(self): + if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")): + raise CompilerError( + "No function to compile. " + "Did you decorate a function with @register_fuction(name)?" + ) + self.prep_compile(self.compile_name) + print( + "Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__) + ) + self.compile_function() + self.finalize_compile() + + def finalize_compile(self): + self.prog.finalize() + + if self.prog.req_num: + print("Program requires at most:") + for x in self.prog.req_num.pretty(): + print(x) + + if self.prog.verbose: + print("Program requires:", repr(self.prog.req_num)) + print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost()) + print("Memory size:", dict(self.prog.allocated_mem)) + + return self.prog + + @staticmethod + def executable_from_protocol(protocol): + match = { + "ring": "replicated-ring", + "rep-field": "replicated-field", + "replicated": "replicated-bin" + } + if protocol in match: + protocol = match[protocol] + if protocol.find("bmr") == -1: + protocol = re.sub("^mal-", "malicious-", protocol) + if protocol == "emulate": + return protocol + ".x" + else: + return protocol + "-party.x" + + def local_execution(self, args=[]): + executable = self.executable_from_protocol(self.options.execute) + if not os.path.exists(executable): + print("Creating binary for virtual machine...") + try: + subprocess.run(["make", executable], check=True) + except: + raise CompilerError( + "Cannot produce %s. " % executable + \ + "Note that compilation requires a few GB of RAM.") + vm = 'Scripts/%s.sh' % self.options.execute + os.execl(vm, vm, self.prog.name, *args) + + def remote_execution(self, args=[]): + vm = self.executable_from_protocol(self.options.execute) + hosts = list(x.strip() + for x in filter(None, open(self.options.hostfile))) + # test availability before compilation + from fabric import Connection + import subprocess + print("Creating static binary for virtual machine...") + subprocess.run(["make", "static/%s" % vm], check=True) + + # transfer files + import glob + hostnames = [] + destinations = [] + for host in hosts: + split = host.split('/', maxsplit=1) + hostnames.append(split[0]) + if len(split) > 1: + destinations.append(split[1]) + else: + destinations.append('.') + connections = [Connection(hostname) for hostname in hostnames] + print("Setting up players...") + + def run(i): + dest = destinations[i] + connection = connections[i] + connection.run( + "mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \ + dest) + # executable + connection.put("static/%s" % vm, dest) + # program + dest += "/" + connection.put("Programs/Schedules/%s.sch" % self.prog.name, + dest + "Programs/Schedules") + for filename in glob.glob( + "Programs/Bytecode/%s-*.bc" % self.prog.name): + connection.put(filename, dest + "Programs/Bytecode") + # inputs + for filename in glob.glob("Player-Data/Input*-P%d-*" % i): + connection.put(filename, dest + "Player-Data") + # key and certificates + for suffix in ('key', 'pem'): + connection.put("Player-Data/P%d.%s" % (i, suffix), + dest + "Player-Data") + for filename in glob.glob("Player-Data/*.0"): + connection.put(filename, dest + "Player-Data") + + import threading + import random + threads = [] + for i in range(len(hosts)): + threads.append(threading.Thread(target=run, args=(i,))) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # execution + threads = [] + # random port numbers to avoid conflict + port = 10000 + random.randrange(40000) + if '@' in hostnames[0]: + party0 = hostnames[0].split('@')[1] + else: + party0 = hostnames[0] + for i in range(len(connections)): + run = lambda i: connections[i].run( + "cd %s; ./%s -p %d %s -h %s -pn %d %s" % \ + (destinations[i], vm, i, self.prog.name, party0, port, + ' '.join(args))) + threads.append(threading.Thread(target=run, args=(i,))) + for thread in threads: + thread.start() + for thread in threads: + thread.join() diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py new file mode 100644 index 000000000..7e25f1591 --- /dev/null +++ b/Compiler/decision_tree.py @@ -0,0 +1,645 @@ +from Compiler.types import * +from Compiler.sorting import * +from Compiler.library import * +from Compiler import util, oram + +from itertools import accumulate +import math + +debug = False +debug_split = False +max_leaves = None + +def get_type(x): + if isinstance(x, (Array, SubMultiArray)): + return x.value_type + elif isinstance(x, (tuple, list)): + x = x[0] + x[-1] + if util.is_constant(x): + return cint + else: + return type(x) + else: + return type(x) + +def PrefixSum(x): + return x.get_vector().prefix_sum() + +def PrefixSumR(x): + tmp = get_type(x).Array(len(x)) + tmp.assign_vector(x) + break_point() + tmp[:] = tmp.get_reverse_vector().prefix_sum() + break_point() + return tmp.get_reverse_vector() + +def PrefixSum_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x, base=1) + tmp[0] = 0 + return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x)) + +def PrefixSumR_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x) + tmp[-1] = 0 + return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x)) + +class SortPerm: + def __init__(self, x): + B = sint.Matrix(len(x), 2) + B.set_column(0, 1 - x.get_vector()) + B.set_column(1, x.get_vector()) + self.perm = Array.create_from(dest_comp(B)) + def apply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, False) + return res + def unapply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, True) + return res + +def Sort(keys, *to_sort, n_bits=None, time=False): + if time: + start_timer(1) + for k in keys: + assert len(k) == len(keys[0]) + n_bits = n_bits or [None] * len(keys) + bs = Matrix.create_from( + sum([k.get_vector().bit_decompose(nb) + for k, nb in reversed(list(zip(keys, n_bits)))], [])) + get_vec = lambda x: x[:] if isinstance(x, Array) else x + res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x + for x in to_sort) + res = res.transpose() + if time: + start_timer(11) + radix_sort_from_matrix(bs, res) + if time: + stop_timer(11) + stop_timer(1) + res = res.transpose() + return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f) + if isinstance(get_vec(y), sfix) + else x for (x, y) in zip(res, to_sort)] + +def VectMax(key, *data, debug=False): + def reducer(x, y): + b = x[0] > y[0] + if debug: + print_ln('max b=%s', b.reveal()) + return [b.if_else(xx, yy) for xx, yy in zip(x, y)] + if debug: + key = list(key) + data = [list(x) for x in data] + print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data)) + res = util.tree_reduce(reducer, zip(key, *data))[1:] + if debug: + print_ln('vect max res=%s', util.reveal(res)) + return res + +def GroupSum(g, x): + assert len(g) == len(x) + p = PrefixSumR(x) * g + pi = SortPerm(g.get_vector().bit_not()) + p1 = pi.apply(p) + s1 = PrefixSumR_inv(p1) + d1 = PrefixSum_inv(s1) + d = pi.unapply(d1) * g + return PrefixSum(d) + +def GroupPrefixSum(g, x): + assert len(g) == len(x) + s = get_type(x).Array(len(x) + 1) + s[0] = 0 + s.assign_vector(PrefixSum(x), base=1) + q = get_type(s).Array(len(x)) + q.assign_vector(s.get_vector(size=len(x)) * g) + return s.get_vector(size=len(x), base=1) - GroupSum(g, q) + +def GroupMax(g, keys, *x): + if debug: + print_ln('group max input g=%s keys=%s x=%s', util.reveal(g), + util.reveal(keys), util.reveal(x)) + assert len(keys) == len(g) + for xx in x: + assert len(xx) == len(g) + n = len(g) + m = int(math.ceil(math.log(n, 2))) + keys = Array.create_from(keys) + x = [Array.create_from(xx) for xx in x] + g_new = Array.create_from(g) + g_old = g_new.same_shape() + for d in range(m): + w = 2 ** d + g_old[:] = g_new[:] + break_point() + vsize = n - w + g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( + g_old.get_vector(size=vsize, base=w)), base=w) + b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w) + for xx in [keys] + x: + a = b.if_else(xx.get_vector(size=vsize), + xx.get_vector(size=vsize, base=w)) + xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( + xx.get_vector(size=vsize, base=w), a), base=w) + break_point() + if debug: + print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(), + util.reveal(a), util.reveal(keys), + util.reveal(x), g_new.reveal()) + t = sint.Array(len(g)) + t[-1] = 1 + t.assign_vector(g.get_vector(size=n - 1, base=1)) + if debug: + print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g), + util.reveal(t), util.reveal(keys), util.reveal(x)) + return [GroupSum(g, t[:] * xx) for xx in [keys] + x] + +def ModifiedGini(g, y, debug=False): + assert len(g) == len(y) + y = [y.get_vector().bit_not(), y] + u = [GroupPrefixSum(g, yy) for yy in y] + s = [GroupSum(g, yy) for yy in y] + w = [ss - uu for ss, uu in zip(s, u)] + us = sum(u) + ws = sum(w) + uqs = u[0] ** 2 + u[1] ** 2 + wqs = w[0] ** 2 + w[1] ** 2 + res = sfix(uqs) / us + sfix(wqs) / ws + if debug: + print_ln('g=%s y=%s s=%s', + util.reveal(g), util.reveal(y), + util.reveal(s)) + print_ln('u0=%s', util.reveal(u[0])) + print_ln('u0=%s', util.reveal(u[1])) + print_ln('us=%s', util.reveal(us)) + print_ln('w0=%s', util.reveal(w[0])) + print_ln('w1=%s', util.reveal(w[1])) + print_ln('ws=%s', util.reveal(ws)) + print_ln('uqs=%s', util.reveal(uqs)) + print_ln('wqs=%s', util.reveal(wqs)) + if debug: + print_ln('gini %s %s', type(res), util.reveal(res)) + return res + +MIN_VALUE = -10000 + +def FormatLayer(h, g, *a): + return CropLayer(h, *FormatLayer_without_crop(g, *a)) + +def FormatLayer_without_crop(g, *a, debug=False): + for x in a: + assert len(x) == len(g) + v = [g.if_else(aa, 0) for aa in a] + if debug: + print_ln('format in %s', util.reveal(a)) + print_ln('format mux %s', util.reveal(v)) + v = Sort([g.bit_not()], *v, n_bits=[1]) + if debug: + print_ln('format sort %s', util.reveal(v)) + return v + +def CropLayer(k, *v): + if max_leaves: + n = min(2 ** k, max_leaves) + else: + n = 2 ** k + return [vv[:min(n, len(vv))] for vv in v] + +def TrainLeafNodes(h, g, y, NID): + assert len(g) == len(y) + assert len(g) == len(NID) + Label = GroupSum(g, y.bit_not()) < GroupSum(g, y) + return FormatLayer(h, g, NID, Label) + +def GroupSame(g, y): + assert len(g) == len(y) + s = GroupSum(g, [sint(1)] * len(g)) + s0 = GroupSum(g, y.bit_not()) + s1 = GroupSum(g, y) + if debug_split: + print_ln('group same g=%s', util.reveal(g)) + print_ln('group same y=%s', util.reveal(y)) + return (s == s0).bit_or(s == s1) + +def GroupFirstOne(g, b): + assert len(g) == len(b) + s = GroupPrefixSum(g, b) + return s * b == 1 + +class TreeTrainer: + """ Decision tree training by `Hamada et al.`_ + + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :param y: binary labels (list or sint vector) + :param h: height (int) + :param binary: binary attributes instead of continuous + :param attr_lengths: attribute description for mixed data + (list of 0/1 for continuous/binary) + :param n_threads: number of threads (default: single thread) + + .. _`Hamada et al.`: https://arxiv.org/abs/2112.12906 + + """ + def ApplyTests(self, x, AID, Threshold): + m = len(x) + n = len(AID) + assert len(AID) == len(Threshold) + for xx in x: + assert len(xx) == len(AID) + e = sint.Matrix(m, n) + AID = Array.create_from(AID) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + e[j][:] = AID[:] == j + xx = sum(x[j] * e[j] for j in range(m)) + if self.debug > 1: + print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx)) + print_ln('threshold %s', util.reveal(Threshold)) + return 2 * xx < Threshold + + def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False): + assert len(g) == len(x) + assert len(g) == len(y) + if time: + start_timer(2) + s = ModifiedGini(g, y, debug=debug or self.debug > 2) + if time: + stop_timer(2) + if debug or self.debug > 1: + print_ln('gini %s', s.reveal()) + xx = x + t = get_type(x).Array(len(x)) + t[-1] = MIN_VALUE + t.assign_vector(xx.get_vector(size=len(x) - 1) + \ + xx.get_vector(size=len(x) - 1, base=1)) + gg = g + p = sint.Array(len(x)) + p[-1] = 1 + p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( + xx.get_vector(size=len(x) - 1) == \ + xx.get_vector(size=len(x) - 1, base=1))) + break_point() + if debug: + print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p)) + s = p[:].if_else(MIN_VALUE, s) + t = p[:].if_else(MIN_VALUE, t[:]) + if debug: + print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) + if time: + start_timer(3) + s, t = GroupMax(gg, s, t) + if time: + stop_timer(3) + if debug: + print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) + return t, s + + def GlobalTestSelection(self, x, y, g): + assert len(y) == len(g) + for xx in x: + assert(len(xx) == len(g)) + m = len(x) + n = len(y) + u, t = [get_type(x).Matrix(m, n) for i in range(2)] + v = get_type(y).Matrix(m, n) + s = sfix.Matrix(m, n) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + single = not self.n_threads or self.n_threads == 1 + time = self.time and single + if debug: + print_ln('run %s', j) + @if_e(self.attr_lengths[j]) + def _(): + u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, + n_bits=[util.log2(n), 1], time=time) + @else_ + def _(): + u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, + n_bits=[util.log2(n), None], + time=time) + if self.debug_threading: + print_ln('global sort %s %s %s', j, util.reveal(u[j]), + util.reveal(v[j])) + t[j][:], s[j][:] = self.AttributeWiseTestSelection( + g, u[j], v[j], time=time, debug=self.debug_selection) + if self.debug_threading: + print_ln('global attribute %s %s %s', j, util.reveal(t[j]), + util.reveal(s[j])) + n = len(g) + a = sint.Array(n) + if self.debug_threading: + print_ln('global s=%s', util.reveal(s)) + if self.debug_gini: + print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)), + *(ss[0].reveal() for ss in s)) + if self.time: + start_timer(4) + if self.debug > 1: + print_ln('s=%s', s.reveal_nested()) + print_ln('t=%s', t.reveal_nested()) + a[:], tt = VectMax((s[j][:] for j in range(m)), range(m), + (t[j][:] for j in range(m)), debug=self.debug > 1) + tt = Array.create_from(tt) + if self.time: + stop_timer(4) + if self.debug > 1: + print_ln('a=%s', util.reveal(a)) + print_ln('tt=%s', util.reveal(tt)) + return a[:], tt[:] + + def TrainInternalNodes(self, k, x, y, g, NID): + assert len(g) == len(y) + for xx in x: + assert len(xx) == len(g) + AID, Threshold = self.GlobalTestSelection(x, y, g) + s = GroupSame(g[:], y[:]) + if self.debug > 1 or debug_split: + print_ln('AID=%s', util.reveal(AID)) + print_ln('Threshold=%s', util.reveal(Threshold)) + print_ln('GroupSame=%s', util.reveal(s)) + AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold) + if self.debug > 1 or debug_split: + print_ln('AID=%s', util.reveal(AID)) + print_ln('Threshold=%s', util.reveal(Threshold)) + b = self.ApplyTests(x, AID, Threshold) + layer = FormatLayer_without_crop(g[:], NID, AID, Threshold, + debug=self.debug > 1) + return *layer, b + + @method_block + def train_layer(self, k): + x = self.x + y = self.y + g = self.g + NID = self.NID + if self.debug > 1: + print_ln('g=%s', g.reveal()) + print_ln('y=%s', y.reveal()) + print_ln('x=%s', x.reveal_nested()) + self.nids[k], self.aids[k], self.thresholds[k], b = \ + self.TrainInternalNodes(k, x, y, g, NID) + if self.debug > 1: + print_ln('layer %s:', k) + for name, data in zip(('NID', 'AID', 'Thr'), + (self.nids[k], self.aids[k], + self.thresholds[k])): + print_ln(' %s: %s', name, data.reveal()) + NID[:] = 2 ** k * b + NID + b_not = b.bit_not() + if self.debug > 1: + print_ln('b_not=%s', b_not.reveal()) + g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b) + y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1]) + for i, xxx in enumerate(xx): + x[i] = xxx + + def __init__(self, x, y, h, binary=False, attr_lengths=None, + n_threads=None): + assert not (binary and attr_lengths) + if binary: + attr_lengths = [1] * len(x) + else: + attr_lengths = attr_lengths or ([0] * len(x)) + for l in attr_lengths: + assert l in (0, 1) + self.attr_lengths = Array.create_from(regint(attr_lengths)) + Array.check_indices = False + Matrix.disable_index_checks() + for xx in x: + assert len(xx) == len(y) + n = len(y) + self.g = sint.Array(n) + self.g.assign_all(0) + self.g[0] = 1 + self.NID = sint.Array(n) + self.NID.assign_all(1) + self.y = Array.create_from(y) + self.x = Matrix.create_from(x) + self.nids, self.aids = [sint.Matrix(h, n) for i in range(2)] + self.thresholds = self.x.value_type.Matrix(h, n) + self.n_threads = n_threads + self.debug_selection = False + self.debug_threading = False + self.debug_gini = False + self.debug = False + self.time = False + + def train(self): + """ Train and return decision tree. """ + h = len(self.nids) + @for_range(h) + def _(k): + self.train_layer(k) + return self.get_tree(h) + + def train_with_testing(self, *test_set, output=False): + """ Train decision tree and test against test data. + + :param y: binary labels (list or sint vector) + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :param output: output tree after every level + :returns: tree + + """ + for k in range(len(self.nids)): + self.train_layer(k) + tree = self.get_tree(k + 1) + if output: + output_decision_tree(tree) + test_decision_tree('train', tree, self.y, self.x, + n_threads=self.n_threads) + if test_set: + test_decision_tree('test', tree, *test_set, + n_threads=self.n_threads) + return tree + + def get_tree(self, h): + Layer = [None] * (h + 1) + for k in range(h): + Layer[k] = CropLayer(k, self.nids[k], self.aids[k], + self.thresholds[k]) + Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID) + return Layer + +def DecisionTreeTraining(x, y, h, binary=False): + return TreeTrainer(x, y, h, binary=binary).train() + +def output_decision_tree(layers): + """ Print decision tree output by :py:class:`TreeTrainer`. """ + print_ln('full model %s', util.reveal(layers)) + for i, layer in enumerate(layers[:-1]): + print_ln('level %s:', i) + for j, x in enumerate(('NID', 'AID', 'Thr')): + print_ln(' %s: %s', x, util.reveal(layer[j])) + print_ln('leaves:') + for j, x in enumerate(('NID', 'result')): + print_ln(' %s: %s', x, util.reveal(layers[-1][j])) + +def pick(bits, x): + if len(bits) == 1: + return bits[0] * x[0] + else: + try: + return x[0].dot_product(bits, x) + except: + return sum(aa * bb for aa, bb in zip(bits, x)) + +def run_decision_tree(layers, data): + """ Run decision tree against sample data. + + :param layers: tree output by :py:class:`TreeTrainer` + :param data: sample data (:py:class:`~Compiler.types.Array`) + :returns: binary label + + """ + h = len(layers) - 1 + index = 1 + for k, layer in enumerate(layers[:-1]): + assert len(layer) == 3 + for x in layer: + assert len(x) <= 2 ** k + bits = layer[0].equal(index, k) + threshold = pick(bits, layer[2]) + key_index = pick(bits, layer[1]) + if key_index.is_clear: + key = data[key_index] + else: + key = pick( + oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) + child = 2 * key < threshold + index += child * 2 ** k + bits = layers[h][0].equal(index, h) + return pick(bits, layers[h][1]) + +def test_decision_tree(name, layers, y, x, n_threads=None, time=False): + if time: + start_timer(100) + n = len(y) + x = x.transpose().reveal() + y = y.reveal() + guess = regint.Array(n) + truth = regint.Array(n) + correct = regint.Array(2) + parts = regint.Array(2) + layers = [[Array.create_from(util.reveal(x)) for x in layer] + for layer in layers] + @for_range_multithread(n_threads, 1, n) + def _(i): + guess[i] = run_decision_tree([[part[:] for part in layer] + for layer in layers], x[i]).reveal() + truth[i] = y[i].reveal() + @for_range(n) + def _(i): + parts[truth[i]] += 1 + c = (guess[i].bit_xor(truth[i]).bit_not()) + correct[truth[i]] += c + print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, + sum(correct), n, correct[0], parts[0], correct[1], parts[1]) + if time: + stop_timer(100) + +class TreeClassifier: + """ Tree classification with convenient interface. Uses + :py:class:`TreeTrainer` internally. + + :param max_depth: the depth of the decision tree + + """ + def __init__(self, max_depth): + self.max_depth = max_depth + + @staticmethod + def get_attr_lengths(attr_types): + if attr_types == None: + return None + else: + return [1 if x == 'b' else 0 for x in attr_types] + + def fit(self, X, y, attr_types=None): + """ Train tree. + + :param X: sample data with row-wise samples (sint/sfix matrix) + :param y: binary labels (sint list/array) + + """ + self.tree = TreeTrainer( + X.transpose(), y, self.max_depth, + attr_lengths=self.get_attr_lengths(attr_types)).train() + + def fit_with_testing(self, X_train, y_train, X_test, y_test, + attr_types=None, output_trees=False, debug=False): + """ Train tree with accuracy output after every level. + + :param X_train: training data with row-wise samples (sint/sfix matrix) + :param y_train: training binary labels (sint list/array) + :param X_test: testing data with row-wise samples (sint/sfix matrix) + :param y_test: testing binary labels (sint list/array) + :param attr_types: attributes types (list of 'b'/'c' for + binary/continuous; default is all continuous) + :param output_trees: output tree after every level + :param debug: output debugging information + + """ + trainer = TreeTrainer(X_train.transpose(), y_train, self.max_depth, + attr_lengths=self.get_attr_lengths(attr_types)) + trainer.debug = debug + trainer.debug_gini = debug + trainer.debug_threading = debug > 1 + self.tree = trainer.train_with_testing(y_test, X_test.transpose(), + output=output_trees) + + def predict(self, X): + """ Use tree for prediction. + + :param X: sample data with row-wise samples (sint/sfix matrix) + :returns: sint array + + """ + res = sint.Array(len(X)) + @for_range(len(X)) + def _(i): + res[i] = run_decision_tree(self.tree, X[i]) + return res + + def output(self): + """ Output decision tree. """ + output_decision_tree(self.tree) + +def preprocess_pandas(data): + """ Preprocess pandas data frame to suit + :py:class:`TreeClassifier` by expanding non-continuous attributes + to several binary attributes as a unary encoding. + + :returns: a tuple of the processed data and a type list for the + :py:obj:`attr_types` argument. + + """ + import pandas + import numpy + res = [] + types = [] + for i, t in enumerate(data.dtypes): + if pandas.api.types.is_int64_dtype(t): + res.append(data.iloc[:,i].to_numpy()) + types.append('c') + elif pandas.api.types.is_object_dtype(t): + values = data.iloc[:,i].unique() + print('converting the following to unary:', values) + if len(values) == 2: + res.append(data.iloc[:,i].to_numpy() == values[1]) + types.append('b') + else: + for value in values: + res.append(data.iloc[:,i].to_numpy() == value) + types.append('b') + else: + raise CompilerError('unknown pandas type: ' + t) + res = numpy.array(res) + res = numpy.swapaxes(res, 0, 1) + return res, types diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index 45d25e6bf..fd57e1b31 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -99,7 +99,7 @@ def bubble_up(self, start): bits.reverse() bits = [0] + floatingpoint.PreOR(bits, self.levels) bits = [bits[i+1] - bits[i] for i in range(self.levels)] - shift = sum([bit << i for i,bit in enumerate(bits)]) + shift = self.int_type.bit_compose(bits) childpos = MemValue(start * shift) @for_range(self.levels - 1) def f(i): @@ -215,12 +215,13 @@ def dump(self, msg=''): print_ln() print_ln() -def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint): - basic_type = int_type.basic_type +def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None): vert_loops = n_loops * e_index.size // edges.size \ if n_loops else -1 dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \ - init_rounds=vert_loops, value_type=basic_type) + init_rounds=vert_loops, value_type=int_type) + int_type = dist.value_type + basic_type = int_type.basic_type #visited = ORAM(e_index.size) #previous = oram_type(e_index.size) Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \ @@ -240,7 +241,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint): u = MemValue(basic_type(0)) @for_range(n_loops or edges.size) def f(i): - cint(i).print_reg('loop') + print_ln('loop %s', i) time() u.write(if_else(last_edge, Q.pop(last_edge), u)) #visited.access(u, True, last_edge) diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py index fd0265637..c68ecd317 100644 --- a/Compiler/exceptions.py +++ b/Compiler/exceptions.py @@ -12,4 +12,7 @@ class ArgumentError(CompilerError): """ Exception raised for errors in instruction argument parsing. """ def __init__(self, arg, msg): self.arg = arg - self.msg = msg \ No newline at end of file + self.msg = msg + +class VectorMismatch(CompilerError): + pass diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 66de0859a..f44d95cbe 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,13 +28,15 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_dabit: + if program.Program.prog.use_edabit(): + r_prime, r = types.sint.get_edabit(k) + elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) r_prime = types.sint.bit_compose(rr) else: r = [types.sint.get_random_bit() for i in range(k)] r_prime = types.sint.bit_compose(r) - c = ((a + r_prime) << shift).reveal() >> shift + c = ((a + r_prime) << shift).reveal(False) >> shift return c, r def maskField(a, k, kappa): @@ -45,7 +47,7 @@ def maskField(a, k, kappa): comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) # always signed due to usage in equality testing a += two_power(k) - asm_open(c, a + two_power(k) * r_dprime + r_prime) + asm_open(True, c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc @@ -231,7 +233,7 @@ def Inv(a): ldi(one, 1) inverse(t[0], t[1]) s = t[0]*a - asm_open(c[0], s) + asm_open(True, c[0], s) # avoid division by zero for benchmarking divc(c[1], one, c[0]) #divc(c[1], c[0], one) @@ -279,7 +281,7 @@ def BitDecRingRaw(a, k, m): else: r_bits = [types.sint.get_random_bit() for i in range(m)] r = types.sint.bit_compose(r_bits) - shifted = ((a - r) << n_shift).reveal() + shifted = ((a - r) << n_shift).reveal(False) masked = shifted >> n_shift bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return bits @@ -290,14 +292,16 @@ def BitDecRing(a, k, m): return [types.sint.conv(bit) for bit in reversed(bits)][::-1] def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): + instructions_base.set_global_vector_size(a.size) r_dprime = types.sint() r_prime = types.sint() c = types.cint() r = [types.sint() for i in range(m)] comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) pow2 = two_power(k + kappa) - asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) + asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) + instructions_base.reset_global_vector_size() return res def BitDecField(a, k, m, kappa, bits_to_compute=None): @@ -307,6 +311,7 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): @instructions_base.ret_cisc def Pow2(a, l, kappa): + comparison.program.curr_tape.require_bit_length(l - 1) m = int(ceil(log(l, 2))) t = BitDec(a, m, m, kappa) return Pow2_from_bits(t) @@ -314,7 +319,7 @@ def Pow2(a, l, kappa): def Pow2_from_bits(bits): m = len(bits) t = list(bits) - pow2k = [types.cint() for i in range(m)] + pow2k = [None for i in range(m)] for i in range(m): pow2k[i] = two_power(2**i) t[i] = t[i]*pow2k[i] + 1 - t[i] @@ -337,10 +342,10 @@ def B2U_from_Pow2(pow2a, l, kappa): if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l assert n_shift > 0 - c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift + c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(t, kappa) - asm_open(c, pow2a + two_power(l) * t + + asm_open(True, c, pow2a + two_power(l) * t + sum(two_power(i) * r[i] for i in range(l))) comparison.program.curr_tape.require_bit_length(l + kappa) c = list(r_bits[0].bit_decompose_clear(c, l)) @@ -382,15 +387,15 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): r_dprime += t1 - t2 if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l - c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift + c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(rk, kappa) r_dprime += two_power(l) * rk - asm_open(c, a + r_dprime + r_prime) + asm_open(True, c, a + r_dprime + r_prime) for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) - lts(d, c_dprime, r_prime, l, kappa) + d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa) if compute_modulo: b = c_dprime - r_prime + pow2m * d return b, pow2m @@ -412,7 +417,7 @@ def TruncInRing(to_shift, l, pow2m): rev *= pow2m r_bits = [types.sint.get_random_bit() for i in range(l)] r = types.sint.bit_compose(r_bits) - shifted = (rev - (r << n_shift)).reveal() + shifted = (rev - (r << n_shift)).reveal(False) masked = shifted >> n_shift bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l)) return types.sint.bit_compose(reversed(bits)) @@ -453,7 +458,7 @@ def Int2FL(a, gamma, l, kappa=None): v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) else: v = 2**(l-gamma+1) * t - p = (p + gamma - 1 - l) * (1 -z) + p = (p + gamma - 1 - l) * z.bit_not() return v, p, z, s def FLRound(x, mode): @@ -526,7 +531,7 @@ def TruncPrRing(a, k, m, signed=True): msb = r_bits[-1] n_shift = n_ring - (k + 1) tmp = a + r - masked = (tmp << n_shift).reveal() + masked = (tmp << n_shift).reveal(False) shifted = (masked << 1 >> (n_shift + m + 1)) overflow = msb.bit_xor(masked >> (n_ring - 1)) res = shifted - upper + \ @@ -547,7 +552,7 @@ def TruncPrField(a, k, m, kappa=None): k, m, kappa, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime - c = (b + r).reveal() + c = (b + r).reveal(False) c_prime = c % two_to_m a_prime = c_prime - r_prime d = (a - a_prime) / two_to_m @@ -627,14 +632,16 @@ def BITLT(a, b, bit_length): # - From the paper # Multiparty Computation for Interval, Equality, and Comparison without # Bit-Decomposition Protocol -def BitDecFull(a, maybe_mixed=False): +def BitDecFull(a, n_bits=None, maybe_mixed=False): from .library import get_program, do_while, if_, break_point from .types import sint, regint, longint, cint p = get_program().prime assert p bit_length = p.bit_length() + n_bits = n_bits or bit_length + assert n_bits <= bit_length logp = int(round(math.log(p, 2))) - if abs(p - 2 ** logp) / p < 2 ** -get_program().security: + if get_program().rabbit_gap(): # inspired by Rabbit (https://eprint.iacr.org/2021/119) # no need for exact randomness generation # if modulo a power of two is close enough @@ -661,26 +668,26 @@ def get_bits_loop(): def _(): for i in range(bit_length): tbits[j][i].link(sint.get_random_bit()) - c = regint(BITLT(tbits[j], pbits, bit_length).reveal()) + c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False)) done[j].link(c) return (sum(done) != a.size) for j in range(a.size): for i in range(bit_length): movs(bbits[i][j], tbits[j][i]) b = sint.bit_compose(bbits) - c = (a-b).reveal() + c = (a-b).reveal(False) cmodp = c t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) czero = (c==0) q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t) fbar = [bbits[0].clear_type.conv(cint(x)) - for x in ((1< 1 else 0) def add_usage(self, req_node): # player 0 as proxy req_node.increment((self.field_type, 'input', 0), float('inf')) + def get_players(self): + pass + @base.gf2n @base.vectorize class rawinput(base.RawInputInstruction, base.Mergeable): @@ -1431,7 +1486,23 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', player), \ self.get_size()) -class inputpersonal(base.Instruction, base.Mergeable): +class personal_base(base.Instruction, base.Mergeable): + __slots__ = [] + field_type = 'modp' + + def __init__(self, *args): + super(personal_base, self).__init__(*args) + for i in range(0, len(args), 4): + assert args[i + 2].size == args[i] + assert args[i + 3].size == args[i] + + def add_usage(self, req_node): + for i in range(0, len(self.args), 4): + player = self.args[i + 1] + req_node.increment((self.field_type, 'input', player), \ + self.args[i]) + +class inputpersonal(personal_base): """ Private input from cint. :param: vector size (int) @@ -1443,19 +1514,47 @@ class inputpersonal(base.Instruction, base.Mergeable): __slots__ = [] code = base.opcodes['INPUTPERSONAL'] arg_format = tools.cycle(['int','p','sw','c']) - field_type = 'modp' - def __init__(self, *args): - super(inputpersonal, self).__init__(*args) - for i in range(0, len(args), 4): - assert args[i + 2].size == args[i] - assert args[i + 3].size == args[i] +class privateoutput(personal_base, base.DataInstruction): + """ Private output to cint. + + :param: vector size (int) + :param: player (int) + :param: destination (cint) + :param: source (sint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['PRIVATEOUTPUT'] + arg_format = tools.cycle(['int','p','cw','s']) + data_type = 'open' def add_usage(self, req_node): - for i in range(0, len(self.args), 4): - player = self.args[i + 1] - req_node.increment((self.field_type, 'input', player), \ - self.args[i]) + personal_base.add_usage(self, req_node) + base.DataInstruction.add_usage(self, req_node) + + def get_repeat(self): + return sum(self.args[::4]) + +class sendpersonal(base.Instruction, base.Mergeable): + """ Private input from cint. + + :param: vector size (int) + :param: destination player (int) + :param: destination (cint) + :param: source player (int) + :param: source (cint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['SENDPERSONAL'] + arg_format = tools.cycle(['int','p','cw','p','c']) + + def __init__(self, *args): + super(sendpersonal, self).__init__(*args) + for i in range(0, len(args), 5): + assert args[i + 2].size == args[i] + assert args[i + 4].size == args[i] @base.gf2n @base.vectorize @@ -1494,6 +1593,14 @@ class cond_print_plain(base.IOInstruction): code = base.opcodes['CONDPRINTPLAIN'] arg_format = ['c', 'c', 'c'] + def __init__(self, *args, **kwargs): + base.Instruction.__init__(self, *args, **kwargs) + self.size = args[1].size + args[2].set_size(self.size) + + def get_code(self): + return base.Instruction.get_code(self, self.size) + class print_int(base.IOInstruction): """ Output clear integer register. @@ -1536,7 +1643,7 @@ class print_char(base.IOInstruction): arg_format = ['int'] def __init__(self, ch): - super(print_char, self).__init__(ord(ch)) + super(print_char, self).__init__(ch) class print_char4(base.IOInstruction): """ Output four bytes. @@ -1589,7 +1696,16 @@ def has_var_args(self): return True class readsockets(base.IOInstruction): - """Read a variable number of secret shares + MACs from socket for a client id and store in registers""" + """ Read a variable number of secret shares (potentially with MAC) + from a socket for a client id and store them in registers. If the + protocol uses MACs, the client should be different for every party. + + :param: client id (regint) + :param: vector size (int) + :param: source (sint) + :param: (repeat source)... + + """ __slots__ = [] code = base.opcodes['READSOCKETS'] arg_format = tools.chain(['ci','int'], itertools.repeat('sw')) @@ -1626,6 +1742,26 @@ class writesocketc(base.IOInstruction): def has_var_args(self): return True +class writesockets(base.IOInstruction): + """ Write a variable number of secret shares (potentially with MAC) + from registers into a socket for a specified client id. If the + protocol uses MACs, the client should be different for every party. + + :param: number of arguments to follow + :param: client id (regint) + :param: message type (must be 0) + :param: vector size (int) + :param: source (sint) + :param: (repeat source)... + + """ + __slots__ = [] + code = base.opcodes['WRITESOCKETS'] + arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s')) + + def has_var_args(self): + return True + class writesocketshare(base.IOInstruction): """ Write a variable number of shares (without MACs) from secret registers into socket for a specified client id. @@ -1689,14 +1825,15 @@ class writesharestofile(base.IOInstruction): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). - :param: number of shares (int) + :param: number of arguments to follow / number of shares plus one (int) + :param: position (regint, -1 for appending) :param: source (sint) :param: (repeat from source)... """ __slots__ = [] code = base.opcodes['WRITEFILESHARE'] - arg_format = itertools.repeat('s') + arg_format = tools.chain(['ci'], itertools.repeat('s')) def has_var_args(self): return True @@ -1750,26 +1887,19 @@ class floatoutput(base.PublicFileIOInstruction): code = base.opcodes['FLOATOUTPUT'] arg_format = ['p','c','c','c','c'] -@base.gf2n @base.vectorize -class startprivateoutput(base.Instruction): - r""" Initiate private output to $n$ of $s_j$ via $s_i$. """ - __slots__ = [] - code = base.opcodes['STARTPRIVATEOUTPUT'] - arg_format = ['sw','s','p'] - field_type = 'modp' +class fixinput(base.PublicFileIOInstruction): + """ Binary fixed-point input. - def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[2]), \ - self.get_size()) + :param: player (int) + :param: destination (cint) + :param: exponent (int) + :param: input type (0: 64-bit integer, 1: float, 2: double) -@base.gf2n -@base.vectorize -class stopprivateoutput(base.Instruction): - r""" Previously iniated private output to $n$ via $c_i$. """ + """ __slots__ = [] - code = base.opcodes['STOPPRIVATEOUTPUT'] - arg_format = ['cw','c','p'] + code = base.opcodes['FIXINPUT'] + arg_format = ['p','cw','int','int'] @base.vectorize class rand(base.Instruction): @@ -2084,17 +2214,26 @@ class gconvgf2n(base.Instruction): # rename 'open' to avoid conflict with built-in open function @base.gf2n @base.vectorize -class asm_open(base.VarArgsInstruction): +class asm_open(base.VarArgsInstruction, base.DataInstruction): """ Reveal secret registers (vectors) to clear registers (vectors). - :param: number of argument to follow (multiple of two) + :param: number of argument to follow (odd number) + :param: check after opening (0/1) :param: destination (cint) :param: source (sint) :param: (repeat the last two)... """ __slots__ = [] code = base.opcodes['OPEN'] - arg_format = tools.cycle(['cw','s']) + arg_format = tools.chain(['int'], tools.cycle(['cw','s'])) + data_type = 'open' + + def get_repeat(self): + return (len(self.args) - 1) // 2 + + def merge(self, other): + self.args[0] |= other.args[0] + self.args += other.args[1:] @base.gf2n @base.vectorize @@ -2171,7 +2310,8 @@ def get_used(self): @base.gf2n @base.vectorize -class dotprods(base.VarArgsInstruction, base.DataInstruction): +class dotprods(base.VarArgsInstruction, base.DataInstruction, + base.DynFormatInstruction): """ Dot product of secret registers (vectors). Note that the vectorized version works element-wise. @@ -2199,31 +2339,30 @@ def __init__(self, *args): flat_args += [x, y] base.Instruction.__init__(self, *flat_args) - @property - def arg_format(self): + @classmethod + def dynamic_arg_format(self, args): field = 'g' if self.is_gf2n() else '' - for i in self.bases(): - yield 'int' + yield 'int' + for i, n in self.bases(args): yield 's' + field + 'w' - for j in range(self.args[i] - 2): + assert n > 2 + for j in range(n - 2): yield 's' + field + yield 'int' - gf2n_arg_format = arg_format - - def bases(self): - i = 0 - while i < len(self.args): - yield i - i += self.args[i] + @property + def gf2n_arg_format(self): + return self.arg_format() def get_repeat(self): - return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() + return sum(self.args[i] // 2 + for i, n in self.bases(iter(self.args))) * self.get_size() def get_def(self): - return [self.args[i + 1] for i in self.bases()] + return [self.args[i + 1] for i, n in self.bases(iter(self.args))] def get_used(self): - for i in self.bases(): + for i, n in self.bases(iter(self.args)): for reg in self.args[i + 2:i + self.args[i]]: yield reg @@ -2334,6 +2473,124 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +class shuffle_base(base.DataInstruction): + n_relevant_parties = 2 + + @staticmethod + def logn(n): + return int(math.ceil(math.log(n, 2))) + + @classmethod + def n_swaps(cls, n): + logn = cls.logn(n) + return logn * 2 ** logn - 2 ** logn + 1 + + def add_gen_usage(self, req_node, n): + # hack for unknown usage + req_node.increment(('bit', 'inverse'), float('inf')) + # minimal usage with two relevant parties + logn = self.logn(n) + n_switches = self.n_swaps(n) + for i in range(self.n_relevant_parties): + req_node.increment((self.field_type, 'input', i), n_switches) + # multiplications for bit check + req_node.increment((self.field_type, 'triple'), + n_switches * self.n_relevant_parties) + + def add_apply_usage(self, req_node, n, record_size): + req_node.increment(('bit', 'inverse'), float('inf')) + logn = self.logn(n) + n_switches = self.n_swaps(n) * self.n_relevant_parties + if n != 2 ** logn: + record_size += 1 + req_node.increment((self.field_type, 'triple'), + n_switches * record_size) + +@base.gf2n +class secshuffle(base.VectorInstruction, shuffle_base): + """ Secure shuffling. + + :param: destination (sint) + :param: source (sint) + """ + __slots__ = [] + code = base.opcodes['SECSHUFFLE'] + arg_format = ['sw','s','int'] + + def __init__(self, *args, **kwargs): + super(secshuffle_class, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) + +class gensecshuffle(shuffle_base): + """ Generate secure shuffle to bit used several times. + + :param: destination (regint) + :param: size (int) + + """ + __slots__ = [] + code = base.opcodes['GENSECSHUFFLE'] + arg_format = ['ciw','int'] + + def add_usage(self, req_node): + self.add_gen_usage(req_node, self.args[1]) + +class applyshuffle(base.VectorInstruction, shuffle_base): + """ Generate secure shuffle to bit used several times. + + :param: destination (sint) + :param: source (sint) + :param: number of elements to be treated as one (int) + :param: handle (regint) + :param: reverse (0/1) + + """ + __slots__ = [] + code = base.opcodes['APPLYSHUFFLE'] + arg_format = ['sw','s','int','ci','int'] + + def __init__(self, *args, **kwargs): + super(applyshuffle, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) + +class delshuffle(base.Instruction): + """ Delete secure shuffle. + + :param: handle (regint) + + """ + code = base.opcodes['DELSHUFFLE'] + arg_format = ['ci'] + +class inverse_permutation(base.VectorInstruction, shuffle_base): + """ Calculate the inverse permutation of a secret permutation. + + :param: destination (sint) + :param: source (sint) + + """ + __slots__ = [] + code = base.opcodes['INVPERM'] + arg_format = ['sw', 's'] + + def __init__(self, *args, **kwargs): + super(inverse_permutation, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + + def add_usage(self, req_node): + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), 1) + + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current @@ -2361,7 +2618,7 @@ def expand(self): c = [program.curr_block.new_reg('c') for i in range(2)] square(s[0], s[1]) subs(s[2], self.args[1], s[0]) - asm_open(c[0], s[2]) + asm_open(False, c[0], s[2]) mulc(c[1], c[0], c[0]) mulm(s[3], self.args[1], c[0]) adds(s[4], s[3], s[3]) @@ -2369,18 +2626,76 @@ def expand(self): subml(self.args[0], s[5], c[1]) -@base.gf2n -@base.vectorize -class lts(base.CISC): - """ Secret comparison $s_i = (s_j < s_k)$. """ - __slots__ = [] - arg_format = ['sw', 's', 's', 'int', 'int'] - - def expand(self): - from .types import sint - a = sint() - subs(a, self.args[1], self.args[2]) - comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) +# placeholder for documentation +class cisc: + """ Meta instruction for emulation. This instruction is only generated + when using ``-K`` with ``compile.py``. The header looks as follows: + + :param: number of arguments after name plus one + :param: name (16 bytes, zero-padded) + + Currently, the following names are supported: + + LTZ + Less than zero. + + :param: number of arguments in this unit (must be 6) + :param: vector size + :param: result (sint) + :param: input (sint) + :param: bit length + :param: (ignored) + :param: (repeat)... + + Trunc + Truncation. + + :param: number of arguments in this unit (must be 8) + :param: vector size + :param: result (sint) + :param: input (sint) + :param: bit length + :param: number of bits to truncate + :param: (ignored) + :param: 0 for unsigned or 1 for signed + :param: (repeat)... + + FPDiv + Fixed-point division. Division by zero results in zero without error. + + :param: number of arguments in this unit (must be at least 7) + :param: vector size + :param: result (sint) + :param: dividend (sint) + :param: divisor (sint) + :param: (ignored) + :param: fixed-point precision + :param: (repeat)... + + exp2_fx + Fixed-point power of two. + + :param: number of arguments in this unit (must be at least 6) + :param: vector size + :param: result (sint) + :param: exponent (sint) + :param: (ignored) + :param: fixed-point precision + :param: (repeat)... + + log2_fx + Fixed-point logarithm with base 2. + + :param: number of arguments in this unit (must be at least 6) + :param: vector size + :param: result (sint) + :param: input (sint) + :param: (ignored) + :param: fixed-point precision + :param: (repeat)... + + """ + code = base.opcodes['CISC'] # hack for circular dependency from Compiler import comparison diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 66b55fd9d..b72079c76 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -4,6 +4,8 @@ import inspect import functools import copy +import sys +import struct from Compiler.exceptions import * from Compiler.config import * from Compiler import util @@ -78,6 +80,7 @@ SUBSI = 0x2A, SUBCFI = 0x2B, SUBSFI = 0x2C, + PREFIXSUMS = 0x2D, # Multiplication/division MULC = 0x30, MULM = 0x31, @@ -103,6 +106,13 @@ MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, + # Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, + INVPERM = 0xFE, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -126,6 +136,7 @@ INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -198,8 +209,9 @@ CONDPRINTPLAIN = 0xE1, INTOUTPUT = 0xE6, FLOATOUTPUT = 0xE7, - GBITDEC = 0x184, - GBITCOM = 0x185, + FIXINPUT = 0xE8, + GBITDEC = 0x18A, + GBITCOM = 0x18B, # Secure socket INITSECURESOCKET = 0x1BA, RESPSECURESOCKET = 0x1BB @@ -215,8 +227,13 @@ def int_to_bytes(x): global_vector_size_stack = [] global_instruction_type_stack = ['modp'] +def check_vector_size(size): + if isinstance(size, program.curr_tape.Register): + raise CompilerError('vector size must be known at compile time') + def set_global_vector_size(size): stack = global_vector_size_stack + check_vector_size(size) if size == 1 and not stack: return stack.append(size) @@ -299,11 +316,12 @@ def maybe_vectorized_instruction(*args, **kwargs): vectorized_name = 'v' + instruction.__name__ Vectorized_Instruction.__name__ = vectorized_name global_dict[vectorized_name] = Vectorized_Instruction + + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction - instruction.__doc__ = '' - # exclude GF(2^n) instructions from documentation - if instruction.code and instruction.code >> 8 == 1: - maybe_vectorized_instruction.__doc__ = '' + maybe_vectorized_instruction.arg_format = instruction.arg_format return maybe_vectorized_instruction @@ -332,7 +350,7 @@ def reformat(arg_format): if isinstance(arg_format, list): __format = [] for __f in arg_format: - if __f in ('int', 'p', 'ci', 'str'): + if __f in ('int', 'long', 'p', 'ci', 'str'): __format.append(__f) else: __format.append(__f[0] + 'g' + __f[1:]) @@ -355,12 +373,13 @@ class GF2N_Instruction(instruction_cls): arg_format = instruction_cls.gf2n_arg_format elif isinstance(instruction_cls.arg_format, itertools.repeat): __f = next(instruction_cls.arg_format) - if __f != 'int' and __f != 'p': + if __f not in ('int', 'long', 'p'): arg_format = itertools.repeat(__f[0] + 'g' + __f[1:]) else: arg_format = copy.deepcopy(instruction_cls.arg_format) reformat(arg_format) + @classmethod def is_gf2n(self): return True @@ -389,8 +408,11 @@ def maybe_gf2n_instruction(*args, **kwargs): else: global_dict[GF2N_Instruction.__name__] = GF2N_Instruction + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction_cls - instruction_cls.__doc__ = '' + maybe_gf2n_instruction.arg_format = instruction.arg_format return maybe_gf2n_instruction #return instruction @@ -404,6 +426,7 @@ class MergeCISC(Mergeable): def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs + self.security = program.security self.calls = [(args, kwargs)] self.params = [] self.used = [] @@ -427,7 +450,7 @@ def is_vec(self): def merge_id(self): return self.function, tuple(self.params), \ - tuple(sorted(self.kwargs.items())) + tuple(sorted(self.kwargs.items())), self.security def merge(self, other): self.calls += other.calls @@ -452,7 +475,10 @@ def new_instructions(self, size, regs): except: args.append(arg) program.options.cisc = False + old_security = program.security + program.security = self.security self.function(*args, **self.kwargs) + program.security = old_security program.options.cisc = True reset_global_vector_size() program.curr_tape = old_tape @@ -481,7 +507,16 @@ def new_instructions(self, size, regs): def expand_merged(self, skip): if function.__name__ in skip: - return [self], 0 + good = True + for call in self.calls: + if not good: + break + for arg in call[0]: + if isinstance(arg, program.curr_tape.Register) and \ + not issubclass(type(self.calls[0][0][0]), type(arg)): + good = False + if good: + return [self], 0 tape = program.curr_tape block = tape.BasicBlock(tape, None, None) tape.active_basicblock = block @@ -490,8 +525,12 @@ def expand_merged(self, skip): for arg in self.args: try: new_regs.append(type(arg)(size=size)) - except: + except TypeError: break + except: + print([call[0][0].size for call in self.calls]) + raise + assert len(new_regs) > 1 base = 0 for call in self.calls: for new_reg, reg in zip(new_regs[1:], call[0][1:]): @@ -514,12 +553,13 @@ def add_usage(self, *args): def get_bytes(self): assert len(self.kwargs) < 2 - res = int_to_bytes(opcodes['CISC']) + res = LongArgFormat.encode(opcodes['CISC']) res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1) name = self.function.__name__ String.check(name) res += String.encode(name) for call in self.calls: + call[1].pop('nearest', None) assert not call[1] res += int_to_bytes(len(call[0]) + 2) res += int_to_bytes(call[0][0].size) @@ -549,7 +589,7 @@ def wrapper(*args, **kwargs): same_sizes &= arg.size == args[0].size except: pass - if program.options.cisc and same_sizes: + if program.use_cisc() and same_sizes: return MergeCISC(*args, **kwargs) else: return function(*args, **kwargs) @@ -562,9 +602,9 @@ def instruction(res, *args, **kwargs): instruction = cisc(instruction) def wrapper(*args, **kwargs): - if not program.options.cisc: - return function(*args, **kwargs) from Compiler import types + if not (program.options.cisc and isinstance(args[0], types._register)): + return function(*args, **kwargs) if isinstance(args[0], types._clear): res_type = type(args[1]) else: @@ -641,7 +681,8 @@ def check(cls, arg): raise ArgumentError(arg, 'Invalid register argument') if arg.program != program.curr_tape: raise ArgumentError(arg, 'Register from other tape, trace: %s' % \ - util.format_trace(arg.caller)) + util.format_trace(arg.caller) + + '\nMaybe use MemValue') if arg.reg_type != cls.reg_type: raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \ (arg.reg_type, cls.reg_type)) @@ -651,6 +692,12 @@ def encode(cls, arg): assert arg.i >= 0 return int_to_bytes(arg.i) + def __init__(self, f): + self.i = struct.unpack('>I', f.read(4))[0] + + def __str__(self): + return self.reg_type + str(self.i) + class ClearModpAF(RegisterArgFormat): reg_type = RegType.ClearModp @@ -667,21 +714,41 @@ class ClearIntAF(RegisterArgFormat): reg_type = RegType.ClearInt class IntArgFormat(ArgFormat): + n_bits = 32 + @classmethod def check(cls, arg): - if not isinstance(arg, int) and not arg is None: - raise ArgumentError(arg, 'Expected an integer-valued argument') + if not arg is None: + if not isinstance(arg, int): + raise ArgumentError(arg, 'Expected an integer-valued argument') + if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits: + raise ArgumentError( + arg, 'Immediate value outside of %d-bit range' % cls.n_bits) @classmethod def encode(cls, arg): return int_to_bytes(arg) + def __init__(self, f): + self.i = struct.unpack('>i', f.read(4))[0] + + def __str__(self): + return str(self.i) + +class LongArgFormat(IntArgFormat): + n_bits = 64 + + @classmethod + def encode(cls, arg): + return list(struct.pack('>q', arg)) + + def __init__(self, f): + self.i = struct.unpack('>q', f.read(8))[0] + class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): super(ImmediateModpAF, cls).check(arg) - if arg >= 2**32 or arg < -2**32: - raise ArgumentError(arg, 'Immediate value outside of 32-bit range') class ImmediateGF2NAF(IntArgFormat): @classmethod @@ -692,6 +759,8 @@ def check(cls, arg): class PlayerNoAF(IntArgFormat): @classmethod def check(cls, arg): + if not util.is_constant(arg): + raise CompilerError('Player number must be known at compile time') super(PlayerNoAF, cls).check(arg) if arg > 256: raise ArgumentError(arg, 'Player number > 256') @@ -712,6 +781,13 @@ def check(cls, arg): def encode(cls, arg): return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg)) + def __init__(self, f): + tmp = f.read(16) + self.str = str(tmp[0:tmp.find(b'\0')], 'ascii') + + def __str__(self): + return self.str + ArgFormats = { 'c': ClearModpAF, 's': SecretModpAF, @@ -726,6 +802,7 @@ def encode(cls, arg): 'i': ImmediateModpAF, 'ig': ImmediateGF2NAF, 'int': IntArgFormat, + 'long': LongArgFormat, 'p': PlayerNoAF, 'str': String, } @@ -766,7 +843,7 @@ def get_code(self, prefix=0): return (prefix << self.code_length) + self.code def get_encoding(self): - enc = int_to_bytes(self.get_code()) + enc = LongArgFormat.encode(self.get_code()) # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) @@ -819,6 +896,7 @@ def has_var_args(self): def is_vec(self): return False + @classmethod def is_gf2n(self): return False @@ -867,6 +945,10 @@ def get_new_args(self, size, subs): new_args.append(arg) return new_args + @staticmethod + def get_usage(args): + return {} + # String version of instruction attempting to replicate encoded version def __str__(self): @@ -880,6 +962,66 @@ def __str__(self): def __repr__(self): return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' +class ParsedInstruction: + reverse_opcodes = {} + + def __init__(self, f): + cls = type(self) + from Compiler import instructions + from Compiler.GC import instructions as gc_inst + if not cls.reverse_opcodes: + for module in instructions, gc_inst: + for x, y in inspect.getmodule(module).__dict__.items(): + if inspect.isclass(y) and y.__name__[0] != 'v': + try: + cls.reverse_opcodes[y.code] = y + except AttributeError: + pass + read = lambda: struct.unpack('>I', f.read(4))[0] + full_code = struct.unpack('>Q', f.read(8))[0] + code = full_code % (1 << Instruction.code_length) + self.size = full_code >> Instruction.code_length + self.type = cls.reverse_opcodes[code] + t = self.type + name = t.__name__ + try: + n_args = len(t.arg_format) + self.var_args = False + except: + n_args = read() + self.var_args = True + try: + arg_format = iter(t.arg_format) + except: + if name == 'cisc': + arg_format = itertools.chain(['str'], itertools.repeat('int')) + else: + def arg_iter(): + i = 0 + while True: + try: + yield self.args[i].i + except AttributeError: + yield None + i += 1 + arg_format = t.dynamic_arg_format(arg_iter()) + self.args = [] + for i in range(n_args): + self.args.append(ArgFormats[next(arg_format)](f)) + + def __str__(self): + name = self.type.__name__ + res = name + ' ' + if self.size > 1: + res = 'v' + res + str(self.size) + ', ' + if self.var_args: + res += str(len(self.args)) + ', ' + res += ', '.join(str(arg) for arg in self.args) + return res + + def get_usage(self): + return self.type.get_usage(self.args) + class VarArgsInstruction(Instruction): def has_var_args(self): return True @@ -891,6 +1033,26 @@ class VectorInstruction(Instruction): def get_code(self): return super(VectorInstruction, self).get_code(len(self.args[0])) +class DynFormatInstruction(Instruction): + __slots__ = [] + + @property + def arg_format(self): + return self.dynamic_arg_format(iter(self.args)) + + @classmethod + def bases(self, args): + i = 0 + while True: + try: + n = next(args) + except StopIteration: + return + yield i, n + i += n + for j in range(n - 1): + next(args) + ### ### Basic arithmetic ### @@ -924,21 +1086,27 @@ class ClearImmediate(ImmediateBase): ### Memory access instructions ### -class DirectMemoryInstruction(Instruction): +class MemoryInstruction(Instruction): + __slots__ = ['_protect'] + def __init__(self, *args, **kwargs): + super(MemoryInstruction, self).__init__(*args, **kwargs) + self._protect = program._protect_memory + +class DirectMemoryInstruction(MemoryInstruction): __slots__ = [] def __init__(self, *args, **kwargs): super(DirectMemoryInstruction, self).__init__(*args, **kwargs) -class IndirectMemoryInstruction(Instruction): +class IndirectMemoryInstruction(MemoryInstruction): __slots__ = [] def get_direct(self, address): return self.direct(self.args[0], address, add_to_prog=False) -class ReadMemoryInstruction(Instruction): +class ReadMemoryInstruction(MemoryInstruction): __slots__ = [] -class WriteMemoryInstruction(Instruction): +class WriteMemoryInstruction(MemoryInstruction): __slots__ = [] class DirectMemoryWriteInstruction(DirectMemoryInstruction, \ @@ -965,12 +1133,16 @@ class IOInstruction(DoNotEliminateInstruction): @classmethod def str_to_int(cls, s): """ Convert a 4 character string to an integer. """ + try: + s = bytearray(s, 'utf8') + except: + pass if len(s) > 4: raise CompilerError('String longer than 4 characters') n = 0 for c in reversed(s.ljust(4)): n <<= 8 - n += ord(c) + n += c return n class AsymmetricCommunicationInstruction(DoNotEliminateInstruction): @@ -989,6 +1161,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction): """ Input from text file or stdin """ __slots__ = [] + def add_usage(self, req_node): + for player in self.get_players(): + req_node.increment((self.field_type, 'input', player), \ + self.get_size()) + ### ### Data access instructions ### diff --git a/Compiler/library.py b/Compiler/library.py index 529608dc2..3f6f2d1c4 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -12,6 +12,7 @@ import random import collections import operator +import copy from functools import reduce def get_program(): @@ -63,6 +64,7 @@ def print_str(s, *args): variables/registers with ``%s``. """ def print_plain_str(ss): """ Print a plain string (no custom formatting options) """ + ss = bytearray(ss, 'utf8') i = 1 while 4*i <= len(ss): print_char4(ss[4*(i-1):4*i]) @@ -115,7 +117,12 @@ def print_ln(s='', *args): print_ln('a is %s.', a.reveal()) """ - print_str(s + '\n', *args) + print_str(str(s) + '\n', *args) + +def print_both(s, end='\n'): + """ Print line during compilation and execution. """ + print(s, end=end) + print_str(s + end) def print_ln_if(cond, ss, *args): """ Print line if :py:obj:`cond` is true. The further arguments @@ -137,7 +144,7 @@ def print_str_if(cond, ss, *args): """ Print string conditionally. See :py:func:`print_ln_if` for details. """ if util.is_constant(cond): if cond: - print_ln(ss, *args) + print_str(ss, *args) else: subs = ss.split('%s') assert len(subs) == len(args) + 1 @@ -153,7 +160,8 @@ def print_str_if(cond, ss, *args): print_str_if(cond, *_expand_to_print(val)) else: print_str_if(cond, str(val)) - s += '\0' * ((-len(s)) % 4) + s = bytearray(s, 'utf8') + s += b'\0' * ((-len(s)) % 4) while s: cond.print_if(s[:4]) s = s[4:] @@ -219,7 +227,10 @@ def crash(condition=None): :param condition: crash if true (default: true) """ - if condition == None: + if isinstance(condition, localint): + # allow crash on local values + condition = condition._v + if condition is None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -239,6 +250,10 @@ def store_in_mem(value, address): try: value.store_in_mem(address) except AttributeError: + if isinstance(value, (list, tuple)): + for i, x in enumerate(value): + store_in_mem(x, address + i) + return # legacy if value.is_clear: if isinstance(address, cint): @@ -257,11 +272,13 @@ def reveal(secret): try: return secret.reveal() except AttributeError: + if secret.is_clear: + return secret if secret.is_gf2n: res = cgf2n() else: res = cint() - instructions.asm_open(res, secret) + instructions.asm_open(True, res, secret) return res @vectorize @@ -278,13 +295,13 @@ def get_arg(): ldarg(res) return res -def make_array(l): +def make_array(l, t=None): if isinstance(l, program.Tape.Register): - res = Array(1, type(l)) - res[0] = l + res = Array(len(l), t or type(l)) + res[:] = l else: l = list(l) - res = Array(len(l), type(l[0]) if l else cint) + res = Array(len(l), t or type(l[0]) if l else cint) res.assign(l) return res @@ -456,6 +473,10 @@ def wrapper(self, *args): return wrapper def cond_swap(x,y): + from .types import SubMultiArray + if isinstance(x, (Array, SubMultiArray)): + b = x[0] > y[0] + return list(zip(*[b.cond_swap(xx, yy) for xx, yy in zip(x, y)])) b = x < y if isinstance(x, sfloat): res = ([], []) @@ -467,11 +488,11 @@ def cond_swap(x,y): res[0].append(bx + yy - by) res[1].append(xx - bx + by) return sfloat(*res[0]), sfloat(*res[1]) - bx = b * x - by = b * y - return bx + y - by, x - bx + by + return b.cond_swap(y, x) def sort(a): + print("WARNING: you're using bubble sort") + res = a for i in range(len(a)): @@ -497,282 +518,36 @@ def odd_even_merge_sort(a): if len(a) == 1: return elif len(a) % 2 == 0: + aa = a + a = list(a) lower = a[:len(a)//2] upper = a[len(a)//2:] odd_even_merge_sort(lower) odd_even_merge_sort(upper) a[:] = lower + upper odd_even_merge(a) + aa[:] = a else: raise CompilerError('Length of list must be power of two') def chunky_odd_even_merge_sort(a): - tmp = a[0].Array(len(a)) - for i,j in enumerate(a): - tmp[i] = j - l = 1 - while l < len(a): - l *= 2 - k = 1 - while k < l: - k *= 2 - def round(): - for i in range(len(a)): - a[i] = tmp[i] - for i in range(len(a) // l): - for j in range(l // k): - base = i * l + j - step = l // k - if k == 2: - a[base], a[base+step] = cond_swap(a[base], a[base+step]) - else: - b = a[base:base+k*step:step] - for m in range(base + step, base + (k - 1) * step, 2 * step): - a[m], a[m+step] = cond_swap(a[m], a[m+step]) - for i in range(len(a)): - tmp[i] = a[i] - chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True) - chunk.start() - chunk.join() - #round() - for i in range(len(a)): - a[i] = tmp[i] + raise CompilerError( + 'This function has been removed, use loopy_odd_even_merge_sort instead') def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False): - if n is None: - n = len(a) - a_base = instructions.program.malloc(n, 's') - for i,j in enumerate(a): - store_in_mem(j, a_base + i) - else: - a_base = a - tmp_base = instructions.program.malloc(n, 's') - chunks = {} - threads = [] - - def run_threads(): - for thread in threads: - thread.start() - for thread in threads: - thread.join() - del threads[:] - - def run_chunk(size, base): - if size not in chunks: - def swap_list(list_base): - for i in range(size // 2): - base = list_base + 2 * i - x, y = cond_swap(sint.load_mem(base), - sint.load_mem(base + 1)) - store_in_mem(x, base) - store_in_mem(y, base + 1) - chunks[size] = FunctionTape(swap_list, 'sort-%d' % size) - return chunks[size](base) - - def run_round(size): - # minimize number of chunk sizes - n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) - lower_size = size // n_chunks // 2 * 2 - n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2 - # print len(to_swap) == lower_size * n_lower_size + \ - # (lower_size + 2) * (n_chunks - n_lower_size), \ - # len(to_swap), n_chunks, lower_size, n_lower_size - base = 0 - round_threads = [] - for i in range(n_lower_size): - round_threads.append(run_chunk(lower_size, tmp_base + base)) - base += lower_size - for i in range(n_chunks - n_lower_size): - round_threads.append(run_chunk(lower_size + 2, tmp_base + base)) - base += lower_size + 2 - run_threads_in_rounds(round_threads) - - postproc_chunks = [] - wrap_chunks = {} - post_threads = [] - pre_threads = [] - - def load_and_store(x, y, to_right): - if to_right: - store_in_mem(sint.load_mem(x), y) - else: - store_in_mem(sint.load_mem(y), x) - - def run_setup(k, a_addr, step, tmp_addr): - if k == 2: - def mem_op(preproc, a_addr, step, tmp_addr): - load_and_store(a_addr, tmp_addr, preproc) - load_and_store(a_addr + step, tmp_addr + 1, preproc) - res = 2 - else: - def mem_op(preproc, a_addr, step, tmp_addr): - instructions.program.curr_tape.merge_opens = False -# for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)): - for i in range(k - 2): - m = a_addr + step + i * step - load_and_store(m, tmp_addr + i, preproc) - res = k - 2 - if not use_chunk_wraps or k <= 4: - mem_op(True, a_addr, step, tmp_addr) - postproc_chunks.append((mem_op, (a_addr, step, tmp_addr))) - else: - if k not in wrap_chunks: - pre_chunk = FunctionTape(mem_op, 'pre-%d' % k, - compile_args=[True]) - post_chunk = FunctionTape(mem_op, 'post-%d' % k, - compile_args=[False]) - wrap_chunks[k] = (pre_chunk, post_chunk) - pre_chunk, post_chunk = wrap_chunks[k] - pre_threads.append(pre_chunk(a_addr, step, tmp_addr)) - post_threads.append(post_chunk(a_addr, step, tmp_addr)) - return res - - def run_threads_in_rounds(all_threads): - for thread in all_threads: - if len(threads) == n_threads: - run_threads() - threads.append(thread) - run_threads() - del all_threads[:] - - def run_postproc(): - run_threads_in_rounds(post_threads) - for chunk,args in postproc_chunks: - chunk(False, *args) - postproc_chunks[:] = [] - - l = 1 - while l < n: - l *= 2 - k = 1 - while k < l: - k *= 2 - size = 0 - instructions.program.curr_tape.merge_opens = False - for i in range(n // l): - for j in range(l // k): - base = i * l + j - step = l // k - size += run_setup(k, a_base + base, step, tmp_base + size) - run_threads_in_rounds(pre_threads) - run_round(size) - run_postproc() - - if isinstance(a, list): - for i in range(n): - a[i] = sint.load_mem(a_base + i) - instructions.program.free(a_base, 's') - instructions.program.free(tmp_base, 's') + raise CompilerError( + 'This function has been removed, use loopy_odd_even_merge_sort instead') def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7): - if n is None: - n = len(a) - a_base = instructions.program.malloc(n, 's') - for i,j in enumerate(a): - store_in_mem(j, a_base + i) - else: - a_base = a - tmp_base = instructions.program.malloc(n, 's') - tmp_i = instructions.program.malloc(1, 'ci') - chunks = {} - threads = [] - - def run_threads(): - for thread in threads: - thread.start() - for thread in threads: - thread.join() - del threads[:] - - def run_threads_in_rounds(all_threads): - for thread in all_threads: - if len(threads) == n_threads: - run_threads() - threads.append(thread) - run_threads() - del all_threads[:] - - def run_chunk(size, base): - if size not in chunks: - def swap_list(list_base): - for i in range(size // 2): - base = list_base + 2 * i - x, y = cond_swap(sint.load_mem(base), - sint.load_mem(base + 1)) - store_in_mem(x, base) - store_in_mem(y, base + 1) - chunks[size] = FunctionTape(swap_list, 'sort-%d' % size) - return chunks[size](base) - - def run_round(size): - # minimize number of chunk sizes - n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) - lower_size = size // n_chunks // 2 * 2 - n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2 - # print len(to_swap) == lower_size * n_lower_size + \ - # (lower_size + 2) * (n_chunks - n_lower_size), \ - # len(to_swap), n_chunks, lower_size, n_lower_size - base = 0 - round_threads = [] - for i in range(n_lower_size): - round_threads.append(run_chunk(lower_size, tmp_base + base)) - base += lower_size - for i in range(n_chunks - n_lower_size): - round_threads.append(run_chunk(lower_size + 2, tmp_base + base)) - base += lower_size + 2 - run_threads_in_rounds(round_threads) - - l = 1 - while l < n: - l *= 2 - k = 1 - while k < l: - k *= 2 - def load_and_store(x, y): - if to_tmp: - store_in_mem(sint.load_mem(x), y) - else: - store_in_mem(sint.load_mem(y), x) - def outer(i): - def inner(j): - base = j + a_base + i * l - step = l // k - if k == 2: - tmp_addr = regint.load_mem(tmp_i) - load_and_store(base, tmp_addr) - load_and_store(base + step, tmp_addr + 1) - store_in_mem(tmp_addr + 2, tmp_i) - else: - def inner2(m): - m += base - tmp_addr = regint.load_mem(tmp_i) - load_and_store(m, tmp_addr) - store_in_mem(tmp_addr + 1, tmp_i) - range_loop(inner2, step, (k - 1) * step, step) - range_loop(inner, l // k) - instructions.program.curr_tape.merge_opens = False - to_tmp = True - store_in_mem(tmp_base, tmp_i) - range_loop(outer, n // l) - if k == 2: - run_round(n) - else: - run_round(n // k * (k - 2)) - instructions.program.curr_tape.merge_opens = False - to_tmp = False - store_in_mem(tmp_base, tmp_i) - range_loop(outer, n // l) - - if isinstance(a, list): - for i in range(n): - a[i] = sint.load_mem(a_base + i) - instructions.program.free(a_base, 's') - instructions.program.free(tmp_base, 's') - instructions.program.free(tmp_i, 'ci') + raise CompilerError( + 'This function has been removed, use loopy_odd_even_merge_sort instead') def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32, n_threads=None): + a_in = a + if isinstance(a_in, list): + a = Array.create_from(a) steps = {} l = sorted_length while l < len(a): @@ -816,8 +591,14 @@ def f(i): swap(m2, step) steps[key] = step steps[key](l) + if isinstance(a_in, list): + a_in[:] = list(a) def mergesort(A): + if not get_program().options.insecure: + raise CompilerError('mergesort reveals the order of elements, ' + 'use --insecure to activate it') + B = Array(len(A), sint) def merge(i_left, i_right, i_end): @@ -845,12 +626,18 @@ def merge_loop(i): width.imul(2) return width < len(A) -def range_loop(loop_body, start, stop=None, step=None): +def _range_prep(start, stop, step): if stop is None: stop = start start = 0 if step is None: step = 1 + if util.is_zero(step): + raise CompilerError('step must not be zero') + return start, stop, step + +def range_loop(loop_body, start, stop=None, step=None): + start, stop, step = _range_prep(start, stop, step) def loop_fn(i): res = loop_body(i) return util.if_else(res == 0, stop, i + step) @@ -859,8 +646,6 @@ def loop_fn(i): condition = lambda x: x < stop elif step < 0: condition = lambda x: x > stop - else: - raise CompilerError('step must not be zero') else: b = step > 0 condition = lambda x: b * (x < stop) + (1 - b) * (x > stop) @@ -870,36 +655,34 @@ def loop_fn(i): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ - lambda x: ((stop - start) // step) * x[0] + lambda x: int(ceil(((stop - start) / step))) * x[0] def for_range(start, stop=None, step=None): """ Decorator to execute loop bodies consecutively. Arguments work as - in Python :py:func:`range`, but they can by any public + in Python :py:func:`range`, but they can be any public integer. Information has to be passed out via container types such - as :py:class:`~Compiler.types.Array` or declaring registers as - :py:obj:`global`. Note that changing Python data structures such + as :py:class:`~Compiler.types.Array` or using :py:func:`update`. + Note that changing Python data structures such as lists within the loop is not possible, but the compiler cannot warn about this. :param start/stop/step: regint/cint/int - Example: - - .. code:: + The following should output 10:: + n = 10 a = sint.Array(n) x = sint(0) @for_range(n) def _(i): a[i] = i - global x - x += 1 + x.update(x + 1) + print_ln('%s', x.reveal()) Note that you cannot overwrite data structures such as - :py:class:`~Compiler.types.Array` in a loop even when using - :py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign` - instead. + :py:class:`~Compiler.types.Array` in a loop. Use + :py:func:`~Compiler.types.Array.assign` instead. """ def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -909,11 +692,13 @@ def decorator(loop_body): def for_range_parallel(n_parallel, n_loops): """ Decorator to execute a loop :py:obj:`n_loops` up to - :py:obj:`n_parallel` loop bodies in parallel. + :py:obj:`n_parallel` loop bodies with optimized communication in a + single thread. + In most cases, it is easier to use :py:func:`for_range_opt`. Using any other control flow instruction inside the loop breaks the optimization. - :param n_parallel: compile-time (int) + :param n_parallel: optimization parameter (int) :param n_loops: regint/cint/int or list of int Example: @@ -937,7 +722,7 @@ def f(i, j): return for_range_multithread(None, n_parallel, n_loops) return map_reduce_single(n_parallel, n_loops) -def for_range_opt(n_loops, budget=None): +def for_range_opt(start, stop=None, step=None, budget=None): """ Execute loop bodies in parallel up to an optimization budget. This prevents excessive loop unrolling. The budget is respected even with nested loops. Note that the optimization is rather @@ -947,8 +732,10 @@ def for_range_opt(n_loops, budget=None): :py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the optimization. - :param n_loops: int/regint/cint - :param budget: number of instructions after which to start optimization (default is 100,000) + :param start/stop/step: int/regint/cint (used as in :py:func:`range`) + or :py:obj:`start` only as list/tuple of int (see below) + :param budget: number of instructions after which to start optimization + (default is 100,000) Example: @@ -968,6 +755,15 @@ def _(i): def f(i, j): ... """ + if stop is not None: + start, stop, step = _range_prep(start, stop, step) + def wrapper(loop_body): + n_loops = (step - 1 + stop - start) // step + @for_range_opt(n_loops, budget=budget) + def _(i): + return loop_body(start + i * step) + return wrapper + n_loops = start if isinstance(n_loops, (list, tuple)): return for_range_opt_multithread(None, n_loops) return map_reduce_single(None, n_loops, budget=budget) @@ -1009,9 +805,11 @@ def write_state_to_memory(r): def f(i): state = tuplify(initializer()) start_block = get_block() + j = i * n_parallel + one = regint(1) for k in range(n_parallel): - j = i * n_parallel + k state = reducer(tuplify(loop_body(j)), state) + j += one if n_parallel > 1 and start_block != get_block(): print('WARNING: parallelization broken ' 'by control flow instruction') @@ -1028,12 +826,15 @@ def _(i): state = tuplify(initializer()) k = 0 block = get_block() + assert not isinstance(n_loops, int) or n_loops > 0 + pre = copy.copy(loop_body.__globals__) while (not util.is_constant(n_loops) or k < n_loops) \ and (len(get_block()) < budget or k == 0) \ and block is get_block(): j = i + k state = reducer(tuplify(loop_body(j)), state) k += 1 + _link(pre, loop_body.__globals__) r = reducer(mem_state, state) write_state_to_memory(r) global n_opt_loops @@ -1064,7 +865,7 @@ def exit_elimination(block): del blocks[-n_to_merge + 1:] del get_tape().req_node.children[-1] merged.children = [] - RegintOptimizer().run(merged.instructions) + RegintOptimizer().run(merged.instructions, get_program()) get_tape().active_basicblock = merged else: req_node = get_tape().req_node.children[-1].nodes[0] @@ -1131,6 +932,15 @@ def _(i): @for_range_opt_multithread(2, [5, 3]) def f(i, j): ... + + Note that you cannot use registers across threads. Use + :py:class:`MemValue` instead:: + + a = MemValue(sint(0)) + @for_range_opt_multithread(8, 80) + def _(i): + b = a + 1 + """ return for_range_multithread(n_threads, None, n_loops) @@ -1142,6 +952,7 @@ def multithread(n_threads, n_items=None, max_size=None): :param n_threads: compile-time (int) :param n_items: regint/cint/int (default: :py:obj:`n_threads`) + :param max_size: maximum size to be processed at once (default: no limit) The following executes ``f(0, 8)``, ``f(8, 8)``, and ``f(16, 9)`` in three different threads: @@ -1158,6 +969,7 @@ def f(base, size): return map_reduce(n_threads, None, n_items, initializer=lambda: [], reducer=None, looping=False) else: + max_size = max(1, max_size) def wrapper(function): @multithread(n_threads, n_items) def new_function(base, size): @@ -1205,7 +1017,13 @@ def decorator(loop_body): if t != regint: raise CompilerError('Not implemented for other than regint') args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') - state = tuple(initializer()) + state = initializer() + if len(state) == 0: + state_type = cint + elif isinstance(state, (tuple, list)): + state_type = type(state[0]) + else: + state_type = type(state) def f(inc): base = args[get_arg()][0] if not util.is_constant(thread_rounds): @@ -1218,8 +1036,7 @@ def f(inc): if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint, \ args[get_arg()].address + 2) - mem_state = Array(len(state), type(state[0]) \ - if state else cint, args[get_arg()][1]) + mem_state = Array(len(state), state_type, args[get_arg()][1]) @map_reduce_single(n_parallel, thread_rounds + inc, \ initializer, reducer, mem_state) def f(i): @@ -1251,14 +1068,14 @@ def f(i): threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) - if state: + if len(state): if thread_rounds: for i in range(n_threads - remainder): - state = reducer(Array(len(state), type(state[0]), \ + state = reducer(Array(len(state), state_type, \ args[remainder + i][1]), state) if remainder: for i in range(remainder): - state = reducer(Array(len(state), type(state[0]).reg_type, \ + state = reducer(Array(len(state), state_type, \ args[i][1]), state) def returner(): return untuplify(state) @@ -1294,7 +1111,50 @@ def summer(i): """ return map_sum(n_threads, None, n_loops, len(types), types) +def map_sum_simple(n_threads, n_loops, type, size): + """ Vectorized multi-threaded sum reduction. The following computes a + 100 sums of ten squares in three threads:: + + @map_sum_simple(3, 10, sint, 100) + def summer(i): + return sint(regint.inc(100, i, 0)) ** 2 + + result = summer() + + :param n_threads: number of threads (int) + :param n_loops: number of loop runs (regint/cint/int) + :param type: return type, must match the return statement + in the loop + :param size: vector size, must match the return statement + in the loop + + """ + initializer = lambda: type(0, size=size) + def summer(*args): + assert len(args) == 2 + args = list(args) + for i in (0, 1): + if isinstance(args[i], tuple): + assert len(args[i]) == 1 + args[i] = args[i][0] + for i in (0, 1): + assert len(args[i]) == size + if isinstance(args[i], Array): + args[i] = args[i][:] + return args[0] + args[1] + return map_reduce(n_threads, 1, n_loops, initializer, summer) + def tree_reduce_multithread(n_threads, function, vector): + """ Round-efficient reduction in several threads. The following code + computes the maximum of an array in 10 threads:: + + tree_reduce_multithread(10, lambda x, y: x.max(y), a) + + :param n_threads: number of threads (int) + :param function: reduction function taking exactly two arguments + :param vector: register vector or array + + """ inputs = vector.Array(len(vector)) inputs.assign_vector(vector) outputs = vector.Array(len(vector) // 2) @@ -1311,6 +1171,18 @@ def _(base, size): left = (left + 1) // 2 return inputs[0] +def tree_reduce(function, sequence): + """ Round-efficient reduction. The following computes the maximum + of the list :py:obj:`l`:: + + m = tree_reduce(lambda x, y: x.max(y), l) + + :param function: reduction function taking two arguments + :param sequence: list, vector, or array + + """ + return util.tree_reduce(function, sequence) + def foreach_enumerate(a): """ Run-time loop over public data. This uses ``Player-Data/Public-Input/``. Example: @@ -1338,61 +1210,57 @@ def f(i): return f return decorator -def while_loop(loop_body, condition, arg, g=None): +def while_loop(loop_body, condition, arg=None, g=None): if not callable(condition): raise CompilerError('Condition must be callable') - # store arg in stack - pre_condition = condition(arg) - if not isinstance(pre_condition, (bool,int)) or pre_condition: + if arg is None: + pre_condition = condition() + def loop_fn(): + loop_body() + return condition() + else: + pre_condition = condition(arg) arg = regint(arg) def loop_fn(): result = loop_body(arg) + if isinstance(result, MemValue): + result = result.read() result.link(arg) - cont = condition(result) - return cont + return condition(result) + if not isinstance(pre_condition, (bool,int)) or pre_condition: if_statement(pre_condition, lambda: do_while(loop_fn, g=g)) def while_do(condition, *args): - """ While-do loop. The decorator requires an initialization, and - the loop body function must return a suitable input for - :py:obj:`condition`. + """ While-do loop. :param condition: function returning public integer (regint/cint/int) - :param args: arguments given to :py:obj:`condition` and loop body The following executes an ten-fold loop: .. code:: - @while_do(lambda x: x < 10, regint(0)) - def f(i): + i = regint(0) + @while_do(lambda: i < 10) + def f(): + ... + i.update(i + 1) ... - return i + 1 + """ def decorator(loop_body): while_loop(loop_body, condition, *args) return loop_body return decorator -def do_loop(condition, loop_fn): - # store initial condition to stack - pushint(condition if isinstance(condition,regint) else regint(condition)) - def wrapped_loop(): - # save condition to stack - new_cond = regint.pop() - # run the loop - condition = loop_fn(new_cond) - pushint(condition) - return condition - do_while(wrapped_loop) - regint.pop() - def _run_and_link(function, g=None): if g is None: g = function.__globals__ - import copy pre = copy.copy(g) res = function() + _link(pre, g) + return res + +def _link(pre, g): if g: from .types import _single for name, var in pre.items(): @@ -1402,7 +1270,6 @@ def _run_and_link(function, g=None): raise CompilerError('cannot reassign constants in blocks') if id(new_var) != id(var): new_var.link(var) - return res def do_while(loop_fn, g=None): """ Do-while loop. The loop is stopped if the return value is zero. @@ -1442,11 +1309,17 @@ class State: pass state = State() if callable(condition): condition = condition() + try: + if not condition.is_clear: + raise CompilerError('cannot branch on secret values') + except AttributeError: + pass state.condition = regint.conv(condition) state.start_block = instructions.program.curr_block state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ name='if-block') state.has_else = False + state.caller = [frame[1:] for frame in inspect.stack()[1:]] instructions.program.curr_tape.if_states.append(state) def else_then(): @@ -1531,6 +1404,8 @@ def decorator(body): def if_e(condition): """ Conditional execution with else block. + Use :py:class:`~Compiler.types.MemValue` to assign values that + live beyond. :param condition: regint/cint/int @@ -1538,12 +1413,13 @@ def if_e(condition): .. code:: + y = MemValue(0) @if_e(x > 0) def _(): - ... + y.write(1) @else_ def _(): - ... + y.write(0) """ try: condition = bool(condition) @@ -1647,11 +1523,18 @@ def get_player_id(): return res def listen_for_clients(port): - """ Listen for clients on specific port. """ + """ Listen for clients on specific port base. + + :param port: port base (int/regint/cint) + """ instructions.listen(regint.conv(port)) def accept_client_connection(port): - """ Listen for clients on specific port. """ + """ Accept client connection on specific port base. + + :param port: port base (int/regint/cint) + :returns: client id + """ res = regint() instructions.acceptclientconnection(res, regint.conv(port)) return res @@ -1779,7 +1662,9 @@ def block(i): return (sign_a * sign_b) * A def IntDiv(a, b, k, kappa=None): - return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k, + l = 2 * k + 1 + b = a.conv(b) + return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k, kappa, nearest=True) @instructions_base.ret_cisc @@ -1801,24 +1686,25 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): theta = int(ceil(log(k/3.5) / log(2))) base.set_global_vector_size(b.size) - alpha = b.get_type(2 * k).two_power(2*f) + alpha = b.get_type(2 * k).two_power(2*f, size=b.size) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() - y = a.extend(2 *k) * w - y = y.round(2*k, f, kappa, nearest, signed=True) + l_y = k + 3 * f - res_f + y = a.extend(l_y) * w + y = y.round(l_y, f, kappa, nearest, signed=True) for i in range(theta - 1): x = x.extend(2 * k) - y = y.extend(2 * k) * (alpha + x).extend(2 * k) + y = y.extend(l_y) * (alpha + x).extend(l_y) x = x * x - y = y.round(2*k, 2*f, kappa, nearest, signed=True) + y = y.round(l_y, 2*f, kappa, nearest, signed=True) x = x.round(2*k, 2*f, kappa, nearest, signed=True) x = x.extend(2 * k) - y = y.extend(2 * k) * (alpha + x).extend(2 * k) - y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) + y = y.extend(l_y) * (alpha + x).extend(l_y) + y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True) return y def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False): diff --git a/Compiler/ml.py b/Compiler/ml.py index 42389ae83..c7f09d09f 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -73,8 +73,13 @@ def log_e(x): return mpc_math.log_fx(x, math.e) +use_mux = False + def exp(x): - return mpc_math.pow_fx(math.e, x) + if use_mux: + return mpc_math.mux_exp(math.e, x) + else: + return mpc_math.pow_fx(math.e, x) def get_limit(x): exp_limit = 2 ** (x.k - x.f - 1) @@ -104,7 +109,7 @@ def sigmoid_prime(x): @vectorize def approx_sigmoid(x, n=3): """ Piece-wise approximate sigmoid as in - `Dahl et al. `_ + `Hong et al. `_ :param x: input :param n: number of pieces, 3 (default) or 5 @@ -148,13 +153,36 @@ def argmax(x): """ Compute index of maximum element. :param x: iterable - :returns: sint + :returns: sint or 0 if :py:obj:`x` has length 1 """ def op(a, b): comp = (a[1] > b[1]) return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1]) return tree_reduce(op, enumerate(x))[0] +def softmax(x): + """ Softmax. + + :param x: vector or list of sfix + :returns: sfix vector + """ + return softmax_from_exp(exp_for_softmax(x)[0]) + +def exp_for_softmax(x): + m = util.max(x) - get_limit(x[0]) + math.log(len(x)) + mv = m.expand_to_vector(len(x)) + try: + x = x.get_vector() + except AttributeError: + x = sfix(x) + if use_mux: + return exp(x - mv), m + else: + return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m + +def softmax_from_exp(x): + return x / sum(x) + report_progress = False def progress(x): @@ -188,9 +216,13 @@ def __getitem__(self, *args): self.alloc() return super(Tensor, self).__getitem__(*args) - def assign_vector(self, *args): + def assign_all(self, *args): self.alloc() - return super(Tensor, self).assign_vector(*args) + return super(Tensor, self).assign_all(*args) + + def assign_vector(self, *args, **kwargs): + self.alloc() + return super(Tensor, self).assign_vector(*args, **kwargs) def assign_vector_by_indices(self, *args): self.alloc() @@ -203,6 +235,7 @@ class Layer: thetas = lambda self: () debug_output = False back_batch_size = 128 + print_random_update = False @property def shape(self): @@ -232,11 +265,15 @@ def forward(self, batch=None, training=None): self._forward(batch) def __str__(self): - return type(self).__name__ + str(self._Y.sizes) + return type(self).__name__ + str(self._Y.shape) + + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.Y.shape) class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None output_weights = lambda *args: None + reveal_parameters_to_binary = lambda *args, **kwargs: None nablas = lambda self: () reset = lambda self: None @@ -268,7 +305,8 @@ def __init__(self, N, debug=False, approx=False): self.compute_loss = True self.d_out = 1 - def divisor(self, divisor, size): + @staticmethod + def divisor(divisor, size=1): return cfix(1.0 / divisor, size=size) def _forward(self, batch): @@ -293,7 +331,8 @@ def _(base, size): self.divisor(N, 1)) def eval(self, size, base=0, top=False): - assert not top + if top: + return self.X.get_vector(base, size) > 0 if self.approx: return approx_sigmoid(self.X.get_vector(base, size), self.approx) else: @@ -351,6 +390,36 @@ def _(i): i, truth, guess, b, nabla) return n_correct +class LinearOutput(NoVariableLayer): + n_outputs = -1 + + def __init__(self, N): + self.X = sfix.Array(N) + self.Y = sfix.Array(N) + self.nabla_X = sfix.Array(N) + self.l = MemValue(sfix(0)) + + def _forward(self, batch): + N = len(batch) + guess = self.X.get_vector(0, N) + truth = self.Y.get(batch.get_vector(0, N)) + diff = guess - truth + self.nabla_X.assign_vector(diff) + #print_ln('%s %s %s', diff.reveal(), truth.reveal(), guess.reveal()) + self.l.write(sum((diff) ** 2) * Output.divisor(N)) + + def backward(self, batch): + pass + + def reveal_correctness(*args): + return 0 + + def average_loss(self, N): + return self.l.reveal() + + def eval(self, size, base=0, top=False): + return self.X.get_vector(base, size) + class MultiOutputBase(NoVariableLayer): def __init__(self, N, d_out, approx=False, debug=False): self.X = sfix.Matrix(N, d_out) @@ -439,6 +508,10 @@ def __init__(self, N, d_out, approx=False, debug=False): self.debug = debug self.true_X = sfix.Array(N) + def __repr__(self): + return '%s(%s, %s, approx=%s)' % \ + (type(self).__name__, self.N, self.d_out, self.approx) + def _forward(self, batch): N = len(batch) d_out = self.X.sizes[1] @@ -464,10 +537,7 @@ def _(i): self.losses[i] = -sfix.dot_product( self.Y[batch[i]].get_vector(), log_e(div)) else: - m = util.max(self.X[i]) - mv = m.expand_to_vector(d_out) - x = self.X[i].get_vector() - e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0) + e, m = exp_for_softmax(self.X[i]) self.exp[i].assign_vector(e) if self.compute_loss: true_X = sfix.dot_product(self.Y[batch[i]], self.X[i]) @@ -532,11 +602,8 @@ def _(j): return @for_range_opt_multithread(self.n_threads, len(batch)) def _(i): - for j in range(d_out): - dividend = self.exp[i][j] - divisor = sum(self.exp[i]) - div = (divisor > 0.1).if_else(dividend / divisor, 0) - self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div) + div = softmax_from_exp(self.exp[i]) + self.nabla_X[i][:] = -self.Y[batch[i]][:] + div self.maybe_debug_backward(batch) def maybe_debug_backward(self, batch): @@ -588,17 +655,37 @@ class DenseBase(Layer): nablas = lambda self: (self.nabla_W, self.nabla_b) def output_weights(self): - print_ln('%s', self.W.reveal_nested()) + self.W.print_reveal_nested() print_ln('%s', self.b.reveal_nested()) + def reveal_parameters_to_binary(self, reshape=None): + if reshape: + trans = self.W.transpose() + O = trans.sizes[0] + tmp = MultiArray([O] + reshape, + value_type=self.W.value_type, + address=trans.address) + X, Y, C = reshape + @for_range(O) + def _(i): + @for_range(C) + def _(j): + part = tmp.get_vector_by_indices(i, None, None, j) + part.reveal().binary_output() + else: + self.W.transpose().reveal_to_binary_output() + if self.input_bias: + self.b.reveal_to_binary_output() + def backward_params(self, f_schur_Y, batch): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) + A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address) + B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + @multithread(self.n_threads, self.d_in) def _(base, size): - A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) mp = B.direct_trans_mul(A, reduce=False, indices=(regint.inc(size, base), batch.get_vector(), @@ -608,16 +695,24 @@ def _(base, size): progress('nabla W (matmul)') - if self.d_in * self.d_out < 200000: - print('reduce at once') - @multithread(self.n_threads, self.d_in * self.d_out) - def _(base, size): - self.nabla_W.assign_vector( - tmp.get_vector(base, size).reduce_after_mul(), base=base) - else: - @for_range_opt(self.d_in) - def _(i): - self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() + @multithread(self.n_threads, self.d_in * self.d_out, + max_size=get_program().budget) + def _(base, size): + self.nabla_W.assign_vector( + tmp.get_vector(base, size).reduce_after_mul(), base=base) + + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % self.d_in + j = regint.get_random(64) % self.d_out + print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s', + str(self.nabla_W), i, j, tmp[i][j].v.reveal(), + self.nabla_W[i][j].reveal(), + A.get_column(j).reveal(), + B.get_column_by_row_indices( + batch.get_vector(), i).reveal()) + print_ln('batch=%s B=%s', batch, + [self.X[bi][0][i].reveal() for bi in batch]) progress('nabla W') @@ -685,15 +780,16 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): self.d_in = d_in self.d_out = d_out self.d = d + self.activation = activation - self.X = MultiArray([N, d, d_in], sfix) - self.Y = MultiArray([N, d, d_out], sfix) + self.X = Tensor([N, d, d_in], sfix) + self.Y = Tensor([N, d, d_out], sfix) self.W = Tensor([d_in, d_out], sfix) self.b = sfix.Array(d_out) back_N = min(N, self.back_batch_size) - self.nabla_Y = MultiArray([back_N, d, d_out], sfix) - self.nabla_X = MultiArray([back_N, d, d_in], sfix) + self.nabla_Y = Tensor([back_N, d, d_out], sfix) + self.nabla_X = Tensor([back_N, d, d_in], sfix) self.nabla_W = sfix.Matrix(d_in, d_out) self.nabla_b = sfix.Array(d_out) @@ -707,12 +803,17 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): else: self.f_input = self.Y + def __repr__(self): + return '%s(%s, %s, %s, activation=%s)' % \ + (type(self).__name__, self.N, self.d_in, + self.d_out, repr(self.activation)) + def reset(self): d_in = self.d_in d_out = self.d_out r = math.sqrt(6.0 / (d_in + d_out)) print('Initializing dense weights in [%f,%f]' % (-r, r)) - self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) + self.W.randomize(-r, r, n_threads=self.n_threads) self.b.assign_all(0) def input_from(self, player, raw=False): @@ -796,6 +897,7 @@ def backward(self, compute_nabla_X=True, batch=None): f_schur_Y = nabla_Y if compute_nabla_X: + nabla_X.alloc() @multithread(self.n_threads, N) def _(base, size): B = sfix.Matrix(N, d_out, address=f_schur_Y.address) @@ -806,6 +908,12 @@ def _(base, size): regint.inc(self.d_in))), base) + if self.print_random_update: + print_ln('backward %s', self) + index = regint.get_random(64) % self.nabla_X.total_size() + print_ln('%s nabla_X at %s: %s', str(self.nabla_X), + index, self.nabla_X.to_array()[index].reveal()) + progress('nabla X') self.backward_params(f_schur_Y, batch=batch) @@ -824,8 +932,8 @@ def __init__(self, N, d_in, d_out): self.b = sfix.Array(d_out) self.nabla_b = self.b.same_shape() - self.X = MultiArray([N, 1, d_in], sfix) - self.Y = MultiArray([N, 1, d_out], sfix) + self.X = Tensor([N, 1, d_in], sfix) + self.Y = Tensor([N, 1, d_out], sfix) self.nabla_Y = self.Y.same_shape() def reset(self): @@ -869,13 +977,17 @@ def __init__(self, N, d1, d2=1, alpha=0.5): self.N = N self.d1 = d1 self.d2 = d2 - self.X = MultiArray([N, d1, d2], sfix) - self.Y = MultiArray([N, d1, d2], sfix) - self.nabla_Y = MultiArray([N, d1, d2], sfix) - self.nabla_X = MultiArray([N, d1, d2], sfix) + self.X = Tensor([N, d1, d2], sfix) + self.Y = Tensor([N, d1, d2], sfix) + self.nabla_Y = Tensor([N, d1, d2], sfix) + self.nabla_X = Tensor([N, d1, d2], sfix) self.alpha = alpha self.B = MultiArray([N, d1, d2], sint) + def __repr__(self): + return '%s(%s, %s, alpha=%s)' % \ + (type(self).__name__, self.N, self.d1, self.alpha) + def forward(self, batch, training=False): if training: n_bits = -math.log(self.alpha, 2) @@ -923,7 +1035,7 @@ def f_part(self, base, size): return self.f(self.X.get_part_vector(base, size)) def f_prime_part(self, base, size): - return self.f_prime(self.Y.get_part_vector(base, size)) + return self.f_prime(self.Y.get_vector(base, size)) def _forward(self, batch=[0]): n_per_item = reduce(operator.mul, self.X.sizes[1:]) @@ -1008,47 +1120,68 @@ class MaxPool(NoVariableLayer): def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), padding='VALID'): assert len(shape) == 4 + assert min(shape) > 0, shape for x in strides, ksize: for i in 0, 3: assert x[i] == 1 self.X = Tensor(shape, sfix) if padding == 'SAME': output_shape = [int(math.ceil(shape[i] / strides[i])) for i in range(4)] + padding = [0, 0] else: - output_shape = [(shape[i] - ksize[i]) // strides[i] + 1 for i in range(4)] + if padding == 'VALID': + padding = 0 + if isinstance(padding, int): + padding = [padding, padding] + output_shape = [shape[0]] + [ + (shape[i + 1] + 2 * padding[i] - ksize[i + 1]) // \ + strides [i + 1] + 1 for i in range(2)] + [shape[3]] self.Y = Tensor(output_shape, sfix) self.strides = strides self.ksize = ksize + self.padding = padding self.nabla_X = Tensor(shape, sfix) self.nabla_Y = Tensor(output_shape, sfix) self.N = shape[0] self.comparisons = MultiArray([self.N, self.X.sizes[3], + output_shape[1], output_shape[2], ksize[1] * ksize[2]], sint) - def _forward(self, batch): + def __repr__(self): + return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \ + (type(self).__name__, self.X.sizes, self.strides, + self.ksize, self.padding) + + def forward(self, batch=None, training=False): + if batch is None: + batch = Array.create_from(regint(0)) def process(pool, bi, k, i, j): def m(a, b): c = a[0] > b[0] l = [c * x for x in a[1]] l += [(1 - c) * x for x in b[1]] return c.if_else(a[0], b[0]), l - red = util.tree_reduce(m, [(x[0], [1]) for x in pool]) + red = util.tree_reduce(m, [(x[0], [1] if training else []) + for x in pool]) self.Y[bi][i][j][k] = red[0] - for i, x in enumerate(red[1]): - self.comparisons[bi][k][i] = x + for ii, x in enumerate(red[1]): + self.comparisons[bi][k][i][j][ii] = x self.traverse(batch, process) def backward(self, compute_nabla_X=True, batch=None): if compute_nabla_X: self.nabla_X.alloc() + self.nabla_X.assign_all(0) + break_point() def process(pool, bi, k, i, j): - for (x, h_in, w_in, h, w), c in zip(pool, - self.comparisons[bi][k]): + for (x, h_in, w_in, h, w), c \ + in zip(pool, self.comparisons[bi][k][i][j]): hh = h * h_in ww = w * w_in - self.nabla_X[bi][hh][ww][k] = \ - util.if_else(h_in * w_in, c * self.nabla_Y[bi][i][j][k], - self.nabla_X[bi][hh][ww][k]) + res = h_in * w_in * c * self.nabla_Y[bi][i][j][k] + get_program().protect_memory(True) + self.nabla_X[bi][hh][ww][k] += res + get_program().protect_memory(False) self.traverse(batch, process) def traverse(self, batch, process): @@ -1058,28 +1191,34 @@ def traverse(self, batch, process): [len(batch), self.X.sizes[3]]) def _(l, k): bi = batch[l] + XX = self.X[bi] @for_range_opt(self.Y.sizes[1]) def _(i): - h_base = self.strides[1] * i + h_base = self.strides[1] * i - self.padding[1] + hs = [h_base + jj for jj in range(self.ksize[1])] + if need_padding[1]: + h_ins = [(h < self.X.sizes[1]) * (h >= 0) for h in hs] + else: + h_ins = [True] * self.ksize[1] @for_range_opt(self.Y.sizes[2]) def _(j): - w_base = self.strides[2] * j + w_base = self.strides[2] * j - self.padding[1] pool = [] + ws = [w_base + jj for jj in range(self.ksize[2])] + if need_padding[2]: + w_ins = [(w < self.X.sizes[2]) * (w >= 0) for w in ws] + else: + w_ins = [True] * self.ksize[2] for ii in range(self.ksize[1]): - h = h_base + ii - if need_padding[1]: - h_in = h < self.X.sizes[1] - else: - h_in = True + h = hs[ii] + h_in = h_ins[ii] + XXX = XX[h_in * h] for jj in range(self.ksize[2]): - w = w_base + jj - if need_padding[2]: - w_in = w < self.X.sizes[2] - else: - w_in = True + w = ws[jj] + w_in = w_ins[jj] if not is_zero(h_in * w_in): - pool.append([h_in * w_in * self.X[bi][h_in * h] - [w_in * w][k], h_in, w_in, h, w]) + pool.append([h_in * w_in * XXX[w_in * w][k], + h_in, w_in, h, w]) process(pool, bi, k, i, j) @@ -1090,7 +1229,7 @@ class Argmax(NoVariableLayer): """ def __init__(self, shape): assert len(shape) == 2 - self.X = MultiArray(shape, sfix) + self.X = Tensor(shape, sfix) self.Y = Array(shape[0], sint) def _forward(self, batch=[0]): @@ -1151,7 +1290,7 @@ def _(base, size): self.Y[batch[0]].assign_vector(tmp, base) class FusedBatchNorm(Layer): - """ Fixed-point fused batch normalization layer. + """ Fixed-point fused batch normalization layer (inference only). :param shape: input/output shape (tuple/list of four int) """ @@ -1178,6 +1317,153 @@ def _(i, j): self.X[batch[0]][i][j].get_vector() * self.weights.get_vector() + self.bias.get_vector()) +class BatchNorm(Layer): + """ Fixed-point batch normalization layer. + + :param shape: input/output shape (tuple/list of four int) + :param approx: use approximate square root + + """ + thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) + + def __init__(self, shape, approx=True, args=None): + assert len(shape) in (2, 3, 4) + if len(shape) == 4: + shape = [shape[0], shape[1] * shape[2], shape[3]] + elif len(shape) == 2: + shape = [shape[0], 1, shape[1]] + tensors = (Tensor(shape, sfix) for i in range(4)) + self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.var, self.mu, self.weights, self.bias = arrays + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays + self.epsilon = 2 ** (-sfix.f * 2 // 3 + 1) + self.momentum = 0.1 + if args != None: + approx = 'precisebn' not in args + self.approx = approx + if approx: + print('Approximate square root inverse in batch normalization') + self.InvertSqrt = mpc_math.InvertSqrt + else: + print('Precise square root inverse in batch normalization') + self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x) + + def __repr__(self): + return '%s(%s, approx=%s)' % \ + (type(self).__name__, self.X.sizes, self.approx) + + def reset(self): + self.bias.assign_all(0) + self.weights.assign_all(1) + self.mu_hat.assign_all(0) + self.var_hat.assign_all(0) + + def _output(self, batch, mu, var): + factor = sfix.Array(len(mu)) + factor[:] = self.InvertSqrt(var[:] + self.epsilon) + @for_range_opt_multithread(self.n_threads, + [len(batch), self.X.sizes[1]]) + def _(i, j): + tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:] + self.Y[i][j][:] = self.bias[:] + tmp + + def forward(self, batch, training=False): + if training: + d = self.X.sizes[1] + d_in = self.X.sizes[2] + s = sfix.Array(d_in) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.X[batch[i]][j].get_vector()) + s.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.mu.assign_vector( + s.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + item = self.X[batch[i]][j].get_vector() + return ((item - self.mu[:]) ** 2) + self.var.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.var.assign_vector( + self.var.get_vector(base, size) / (len(batch) * d - 1), + base) + for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var): + x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:] + self._output(batch, self.mu, self.var) + if self.print_random_update: + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.mu, self.var: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', + str(self.Y), i, j, k, self.X[i][j][k].reveal(), + self.Y[i][j][k].reveal()) + else: + self._output(batch, self.mu_hat, self.var_hat) + + def backward(self, batch, compute_nabla_X=True): + factor = Array.create_from( + self.InvertSqrt(self.var[:] + self.epsilon)) + mynYf = self.X.same_shape() + gamnY = self.X.same_shape() + gamnYd = self.X.same_shape() + nYdf = self.X.same_shape() + d = self.X.sizes[1] + d_in = self.X.sizes[2] + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + tmp = self.weights[:] * self.nabla_Y[i][j][:] + gamnY[i][j] = tmp + gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:]) + mynYf[i][j] = tmp * factor[:] + nYdf[i][j] = self.nabla_Y[i][j][:] * \ + (self.X[i][j][:] - self.mu[:]) * factor[:] + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.nabla_Y[i][j][:]) + self.nabla_bias.assign(_()) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (nYdf[i][j]) + self.nabla_weights.assign(_()) + factor3 = Array.create_from(factor[:] ** 3) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (mynYf[i][j]) + s1 = Array.create_from(_()) + @multithread(self.n_threads, len(s1)) + def _(base, size): + s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (gamnYd[i][j][:] * factor3[:]) + s2 = Array.create_from(_()) + @multithread(self.n_threads, len(s2)) + def _(base, size): + s2.assign_vector( + s2.get_vector(base, size) / (len(batch) * d - 1), base) + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + self.nabla_X[i][j][:] = mynYf[i][j][:] \ + - s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:] + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.nabla_bias, self.nabla_weights: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k, + self.nabla_Y[i][j][k].reveal(), + self.nabla_X[i][j][k].reveal()) + class QuantBase(object): bias_before_reduction = True @@ -1232,8 +1518,8 @@ def __init__(self, input_shape, output_shape, inputs=None): for x in back_shapes: x[0] = min(x[0], self.back_batch_size) - self.nabla_X = MultiArray(back_shapes[0], self.input_squant) - self.nabla_Y = MultiArray(back_shapes[1], self.output_squant) + self.nabla_X = Tensor(back_shapes[0], self.input_squant) + self.nabla_Y = Tensor(back_shapes[1], self.output_squant) self.inputs = inputs def temp_shape(self): @@ -1284,6 +1570,8 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, self.padding.append(pad_total // 2) elif padding == 'VALID': self.padding = [0, 0] + elif isinstance(padding, int): + self.padding = [padding, padding] else: self.padding = padding @@ -1309,6 +1597,12 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, assert(len(output_shape) == 4) assert(len(weight_shape) == 4) + def __repr__(self): + return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \ + (type(self).__name__, self.X.sizes, self.weight_shape, + self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), + self.tf_weight_format) + def input_from(self, player, raw=False): self.input_params_from(player) self.weights.input_from(player, budget=100000, raw=raw) @@ -1316,9 +1610,21 @@ def input_from(self, player, raw=False): self.bias.input_from(player, raw=raw) def output_weights(self): - print_ln('%s', self.weights.reveal_nested()) + self.weights.print_reveal_nested() print_ln('%s', self.bias.reveal_nested()) + def reveal_parameters_to_binary(self): + assert not self.tf_weight_format + n_filters = self.weights.shape[0] + n_channels = self.weights.shape[3] + @for_range(n_filters) + def _(i): + @for_range(n_channels) + def _(j): + part = self.weights.get_vector_by_indices(i, None, None, j) + part.reveal().binary_output() + self.bias.reveal_to_binary_output() + def dot_product(self, iv, wv, out_y, out_x, out_c): bias = self.bias[out_c] acc = self.output_squant.unreduced_dot_product(iv, wv) @@ -1479,11 +1785,10 @@ class FixConv2d(Conv2d, FixBase): def reset(self): assert not self.tf_weight_format - kernel_size = self.weight_shape[1] * self.weight_shape[2] - r = math.sqrt(6.0 / (kernel_size * sum(self.weight_shape[::3]))) + n_in = reduce(operator.mul, self.weight_shape[1:]) + r = math.sqrt(6.0 / (n_in + self.weight_shape[0])) print('Initializing convolution weights in [%f,%f]' % (-r, r)) - self.weights.assign_vector( - sfix.get_random(-r, r, size=self.weights.total_size())) + self.weights.randomize(-r, r, n_threads=self.n_threads) self.bias.assign_all(0) def backward(self, compute_nabla_X=True, batch=None): @@ -1531,20 +1836,20 @@ def _(i, j): self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i) if compute_nabla_X: - assert tuple(self.padding) == (0, 0) assert tuple(self.stride) == (1, 1) reverse_weights = MultiArray( [n_channels_in, weights_h, weights_w, n_channels_out], sfix) - @for_range(n_channels_out) - def _(i): + @for_range_opt_multithread(self.n_threads, n_channels_in) + def _(l): @for_range(weights_h) def _(j): @for_range(weights_w) def _(k): - @for_range(n_channels_in) - def _(l): - reverse_weights[l][weights_h-j-1][k][i] = \ - self.weights[i][j][weights_w-k-1][l] + addresses = regint.inc(n_channels_out, + self.weights[0][j][weights_w-k-1].get_address(l), + reduce(operator.mul, self.weights.sizes[1:])) + reverse_weights[l][weights_h-j-1][k].assign_vector( + self.weights.value_type.load_mem(addresses)) padded_w = inputs_w + 2 * padding_w padded_h = inputs_h + 2 * padding_h if padding_h or padding_w: @@ -1565,14 +1870,16 @@ def _(i, j): unreduced_sfix._new(res).reduce_after_mul(), i, None, None, j) if padding_h or padding_w: - @for_range(N) + @for_range_opt_multithread(self.n_threads, N) def _(i): @for_range(inputs_h) def _(j): @for_range(inputs_w) def _(k): + jj = j + padding_w + kk = k + padding_w self.nabla_X[i][j][k].assign_vector( - output[i][j][k].get_vector()) + output[i][jj][kk].get_vector()) if self.debug_output: @for_range(len(batch)) @@ -1717,6 +2024,51 @@ def _(out_y, out_x, c): acc = self.const_div(acc, n) self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) +def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, + padding=0): + """ More convenient interface to :py:class:`FixConv2d`. + + :param input_shape: input shape (tuple/list of four int) + :param out_channels: output channels (int) + :param kernel_size: kernel size (int or tuple/list of two int) + :param stride: stride (int or tuple/list of two int) + :param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int, or tuple/list of two int + + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) + weight_shape = [out_channels] + list(kernel_size) + [input_shape[-1]] + output_shape = [batch_size] + list( + apply_padding(input_shape[1:3], kernel_size, stride, padding)) + \ + [out_channels] + padding = padding.upper() if isinstance(padding, str) \ + else padding + return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape, + stride, padding) + +def easyMaxPool(input_shape, kernel_size, stride=None, padding=0): + """ More convenient interface to :py:class:`MaxPool`. + + :param input_shape: input shape (tuple/list of four int) + :param kernel_size: kernel size (int or tuple/list of two int) + :param stride: stride (int or tuple/list of two int) + :param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int, + or tuple/list of two int + + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) + if stride == None: + stride = kernel_size + padding = padding.upper() if isinstance(padding, str) \ + else padding + return MaxPool(input_shape, [1] + list(stride) + [1], + [1] + list(kernel_size) + [1], padding) + class QuantAveragePool2d(QuantBase, AveragePool2d): def input_params_from(self, player): print('WARNING: assuming that input and output quantization parameters are the same') @@ -1770,9 +2122,15 @@ class Optimizer: """ Base class for graphs of layers. """ n_threads = Layer.n_threads always_shuffle = True + shuffle = True time_layers = False revealing_correctness = False early_division = False + output_diff = False + output_grad = False + output_stats = False + print_accuracy = True + time_training = True @staticmethod def from_args(program, layers): @@ -1780,22 +2138,33 @@ def from_args(program, layers): res = Adam(layers, 1, approx='adamapprox' in program.args) elif 'amsgrad' in program.args: res = Adam(layers, approx=True, amsgrad=True) + elif 'amsgradprec' in program.args: + res = Adam(layers, approx=False, amsgrad=True) elif 'quotient' in program.args: res = Adam(layers, approx=True, amsgrad=True, normalize=True) else: res = SGD(layers, 1) res.early_division = 'early_div' in program.args + res.output_diff = 'output_diff' in program.args + res.output_grad = 'output_grad' in program.args + res.output_stats = 'output_stats' in program.args return res - def __init__(self, report_loss=None): + def __init__(self, layers=[], report_loss=None): + if get_program().options.binary: + raise CompilerError( + 'machine learning code not compatible with binary circuits') self.tol = 0.000 self.report_loss = report_loss self.X_by_label = None self.print_update_average = False + self.print_random_update = False self.print_losses = False self.print_loss_reduction = False self.i_epoch = MemValue(0) self.stopped_on_loss = MemValue(0) + self.stopped_on_low_loss = MemValue(0) + self.layers = layers @property def layers(self): @@ -1822,6 +2191,10 @@ def set_layers_with_inputs(self, layers): layer.last_used = list(filter(lambda x: x not in used, layer.inputs)) used.update(layer.inputs) + def set_learning_rate(self, lr): + print('Setting learning rate to', lr) + self.gamma = MemValue(cfix(lr)) + def reset(self): """ Initialize weights. """ for layer in self.layers: @@ -1831,6 +2204,7 @@ def reset(self): def batch_for(self, layer, batch): if layer in (self.layers[0], self.layers[-1]): + assert not isinstance(layer, BatchNorm) return batch else: batch = regint.Array(len(batch)) @@ -1861,6 +2235,21 @@ def forward(self, N=None, batch=None, keep_intermediate=True, if i != len(self.layers) - 1 or run_last: layer.forward(batch=self.batch_for(layer, batch), training=training) + if self.print_random_update: + print_ln('forward layer %s', layer) + l = min(100, layer.Y[i].total_size()) + i = regint.get_random(64) % len(batch) + if l < 100: + j = 0 + else: + j = regint.get_random(64) % \ + (layer.Y[i].total_size() - l) + print_ln('forward layer %s at (%s, %s): %s', layer, i, j, + layer.Y[i].to_array().get_vector(j, l).reveal()) + i = regint.get_random(64) % layer.Y[0].total_size() + print_ln('forward layer %s vertical at %s: %s', layer, i, + [layer.Y[j].to_array()[i].reveal() + for j in range(len(batch))]) if self.time_layers: stop_timer(100 + i) break_point() @@ -1903,6 +2292,7 @@ def backward(self, batch): layer.backward(compute_nabla_X=False, batch=self.batch_for(layer, batch)) else: + layer.nabla_X.alloc() layer.backward(batch=self.batch_for(layer, batch)) if len(layer.inputs) == 1: layer.inputs[0].nabla_Y.address = \ @@ -1913,11 +2303,98 @@ def backward(self, batch): if self.time_layers: stop_timer(200 + i) + @classmethod + def stat(cls, name, tensor): + zero, neg, small = (cint.Array(cls.n_threads) for i in range(3)) + s, mx, mn = (cfix.Array(cls.n_threads) for i in range(3)) + for x in zero, neg, small, s, mx, mn: + x.assign_all(0) + total = tensor.total_size() + @multithread(cls.n_threads, total) + def _(base, size): + tn = get_thread_number() - 1 + tmp = Array.create_from( + tensor.get_vector(base, size).reveal()) + @for_range_opt(size, budget=1000) + def _(i): + zero[tn] += tmp[i] == 0 + neg[tn] += tmp[i] < 0 + small[tn] += abs(tmp[i]) < 2 ** (-tmp[i].f / 2) + s[tn] += tmp[i] + mx[tn] = util.max(mx[tn], tmp[i]) + mn[tn] = util.min(mn[tn], tmp[i]) + tmp.delete() + print_str( + ' %s 0:%s/%s, <0:%s/%s, >0:%s/%s, ~0:%s/%s sum:%s max:%s min:%s ', + name, sum(zero), total, sum(neg), total, + total - sum(zero) - sum(neg), total, + sum(small) - sum(zero), total, sum(s), util.max(mx), util.min(mn)) + if len(tensor.shape) == 4: + corners = sum(([tensor[0][i][j][0] for j in (0, -1)] + for i in (0, -1)), []) + elif len(tensor.shape) == 1: + x = tensor.to_array() + corners = [x[i] for i in (0, len(x) // 2 - 1, -1)] + else: + x = tensor[0].to_array() + corners = [x[i] for i in (0, len(x) // 2 - 1, -1)] + print_ln('corners:%s shape:%s', util.reveal(corners), tensor.shape) + + def update(self, i_epoch, i_batch, batch): + if self.output_grad: + @if_(i_batch % 100 == 0) + def _(): + for layer in self.layers[:-1]: + cfix(10000).binary_output() + break_point() + layer.nabla_Y.get_vector(size=2000).reveal().binary_output() + break_point() + for theta, nabla in zip(layer.thetas(), layer.nablas()): + cfix(5000).binary_output() + break_point() + nabla.get_vector().reveal().binary_output() + break_point() + if self.output_stats: + old_params = [] + @if_((i_batch % self.output_stats == 0).bit_or(i_epoch == 0)) + def _(): + for i, layer in enumerate(self.layers[:-1]): + print_ln(layer) + if layer == self.layers[0]: + x = Array.create_from(layer.X.get_slice_vector(batch)) + self.stat(' 0 X', x) + else: + self.stat(' %d X' % i, layer.X) + self.stat(' %d Y' % i, layer.Y) + self.stat(' %d nabla_Y' % i, layer.nabla_Y) + for nabla in layer.nablas(): + self.stat(' %d grad' % i, nabla) + for theta in layer.thetas(): + self.stat(' %d param' % i, theta) + if theta.total_size() < 1000: + old_params.append(theta.get_vector()) + if self.time_layers: + start_timer(1000) + self._update(i_epoch, MemValue(i_batch), batch) + if self.time_layers: + stop_timer(1000) + if self.output_stats: + @if_(i_batch % self.output_stats == 0) + def _(): + for i, layer in enumerate(self.layers[:-1]): + for theta in layer.thetas(): + if theta.total_size() < 1000: + print_ln(layer) + self.stat(' %d diff' % i, Array.create_from( + theta.get_vector() - old_params[0])) + del old_params[0] + @_no_mem_warnings def run(self, batch_size=None, stop_on_loss=0): """ Run training. :param batch_size: batch size (defaults to example size of first layer) + :param stop_on_loss: stop when loss falls below this (default: 0) """ if self.n_epochs == 0: return @@ -1942,8 +2419,13 @@ def _(_): for label, X in enumerate(self.X_by_label): indices = regint.Array(n * n_per_epoch) indices_by_label.append(indices) - indices.assign(regint.inc(len(indices), 0, 1, 1, len(X))) - if self.always_shuffle or n_per_epoch > 1: + indices.assign(regint.inc(len(X))) + missing = len(indices) - len(X) + if missing: + indices.assign_vector( + regint.get_random(int(math.log2(len(X))), size=missing), + base=len(X)) + if self.shuffle and (self.always_shuffle or n_per_epoch > 1): indices.shuffle() loss_sum = MemValue(sfix(0)) self.n_correct.write(0) @@ -1958,7 +2440,7 @@ def _(j): label * n) self.forward(batch=batch, training=True) self.backward(batch=batch) - self.update(i, batch=batch) + self.update(i, j, batch=batch) loss_sum.iadd(self.layers[-1].l) if self.print_loss_reduction: before = self.layers[-1].average_loss(N) @@ -1983,19 +2465,36 @@ def _(j): return res if self.print_losses: print_ln() + self.missing_newline = False if self.report_loss and self.layers[-1].compute_loss and self.layers[-1].approx != 5: print_ln('loss in epoch %s: %s', i, (loss_sum.reveal() * cfix(1 / n_per_epoch))) else: - print_ln('done with epoch %s', i) - time() + print_str('done with epoch %s', i) + if self.time_training or self.print_losses: + print_ln() + else: + print_str('\r') + self.missing_newline = True + if self.time_training: + time() i.iadd(1) res = True if self.tol > 0: - res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() + res *= (1 - (loss_sum >= 0) * \ + (loss_sum < self.tol * n_per_epoch)).reveal() + self.stopped_on_low_loss.write(1 - res) return res - def reveal_correctness(self, data, truth, batch_size): + def reveal_correctness(self, data, truth, batch_size=128, running=False): + """ Test correctness by revealing results. + + :param data: test sample data + :param truth: test labels + :param batch_size: batch size + :param running: output after every batch + + """ N = data.sizes[0] n_correct = MemValue(0) loss = MemValue(sfix(0)) @@ -2006,13 +2505,20 @@ def f(start, batch_size, batch): n_correct.iadd( self.layers[-1].reveal_correctness(batch_size, part_truth)) loss.iadd(self.layers[-1].l * batch_size) - self.run_in_batches(f, data, batch_size) + if running: + total = start + batch_size + print_str('\rpart acc: %s (%s/%s) ', + cfix(n_correct, k=63, f=31) / total, n_correct, total) + self.run_in_batches(f, data, batch_size, truth) + if running: + print_ln() loss = loss.reveal() if cfix.f < 31: loss = cfix._new(loss.v << (31 - cfix.f), k=63, f=31) return n_correct, loss / N def run_in_batches(self, f, data, batch_size, truth=None): + batch_size = min(batch_size, data.sizes[0]) training_data = self.layers[0].X.address training_truth = self.layers[-1].Y.address self.layers[0].X.address = data.address @@ -2027,43 +2533,54 @@ def _(i): batch_size = N % batch_size if batch_size: start = N - batch_size - f(start, batch_size, batch) + f(start, batch_size, regint.Array(batch_size)) self.layers[0].X.address = training_data self.layers[-1].Y.address = training_truth @_no_mem_warnings def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, - acc_batch_size=None): + acc_batch_size=None, reset=True): if acc_batch_size is None: acc_batch_size = batch_size depreciation = None + if program is None: + class A: + pass + program = A() + program.args = [] for arg in program.args: m = re.match('rate(.*)', arg) if m: - self.gamma = MemValue(cfix(float(m.group(1)))) + self.set_learning_rate(float(m.group(1))) m = re.match('dep(.*)', arg) if m: depreciation = float(m.group(1)) if 'nomom' in program.args: self.momentum = 0 - self.print_losses = 'print_losses' in program.args + self.print_losses |= 'print_losses' in program.args + self.print_random_update = 'print_random_update' in program.args + Layer.print_random_update = self.print_random_update self.time_layers = 'time_layers' in program.args - self.revealing_correctness = not 'no_acc' in program.args + self.revealing_correctness &= not 'no_acc' in program.args self.layers[-1].compute_loss = not 'no_loss' in program.args + if 'full_cisc' in program.args: + program.options.keep_cisc = 'FPDiv,exp2_fx,log2_fx' model_input = 'model_input' in program.args acc_first = model_input and not 'train_first' in program.args if model_input: for layer in self.layers: layer.input_from(0) - else: + elif reset: self.reset() if 'one_iter' in program.args: + print_float_prec(16) self.output_weights() print_ln('loss') - print_ln('%s', self.eval( - self.layers[0].X.get_part(0, batch_size)).reveal_nested()) + self.eval( + self.layers[0].X.get_part(0, batch_size), + batch_size=batch_size).print_reveal_nested() for layer in self.layers: - print_ln('%s', layer.X.get_part(0, batch_size).reveal_nested()) + layer.X.get_part(0, batch_size).print_reveal_nested() print_ln('%s', self.layers[-1].Y.get_part(0, batch_size).reveal_nested()) batch = Array.create_from(regint.inc(batch_size)) self.forward(batch=batch, training=True) @@ -2072,35 +2589,55 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, print_ln('loss %s', self.layers[-1].l.reveal()) self.output_weights() return + if 'bench10' in program.args or 'bench1' in program.args: + n = 1 if 'bench1' in program.args else 10 + print('benchmarking %s iterations' % n) + @for_range(n) + def _(i): + batch = Array.create_from(regint.inc(batch_size)) + self.forward(batch=batch, training=True) + self.backward(batch=batch) + self.update(0, batch=batch) + return @for_range(n_runs) def _(i): if not acc_first: - start_timer(1) + if self.time_training: + start_timer(1) self.run(batch_size, stop_on_loss=0 if 'no_loss' in program.args else 100) - stop_timer(1) + if self.time_training: + stop_timer(1) if 'no_acc' in program.args: return N = self.layers[0].X.sizes[0] n_trained = (N + batch_size - 1) // batch_size * batch_size - print_ln('train_acc: %s (%s/%s)', - cfix(self.n_correct, k=63, f=31) / n_trained, - self.n_correct, n_trained) + if not acc_first and self.print_accuracy and \ + self.revealing_correctness: + print_ln('train_acc: %s (%s/%s)', + cfix(self.n_correct, k=63, f=31) / n_trained, + self.n_correct, n_trained) if test_X and test_Y: + print('use test set') n_test = len(test_Y) - n_correct, loss = self.reveal_correctness(test_X, test_Y, - acc_batch_size) + n_correct, loss = self.reveal_correctness( + test_X, test_Y, acc_batch_size, + running='part_acc' in program.args) print_ln('test loss: %s', loss) - print_ln('acc: %s (%s/%s)', - cfix(n_correct, k=63, f=31) / n_test, - n_correct, n_test) + if self.print_accuracy: + print_ln('acc: %s (%s/%s)', + cfix(n_correct, k=63, f=31) / n_test, + n_correct, n_test) if acc_first: - start_timer(1) + if self.time_training: + start_timer(1) self.run(batch_size) - stop_timer(1) + if self.time_training: + stop_timer(1) else: - @if_(util.or_op(self.stopped_on_loss, n_correct < - int(n_test // self.layers[-1].n_outputs * 1.2))) + @if_(util.or_op(self.stopped_on_loss, (n_correct < + int(n_test // self.layers[-1].n_outputs * 1.2)) + if test_X and test_Y else 0)) def _(): self.gamma.imul(.5) if 'crash' in program.args: @@ -2113,14 +2650,60 @@ def _(): if depreciation: self.gamma.imul(depreciation) print_ln('reducing learning rate to %s', self.gamma) + return 1 - self.stopped_on_low_loss + if self.missing_newline: + print_ln('') if 'model_output' in program.args: self.output_weights() + def fit(self, X, Y, epochs=1, batch_size=128, validation_data=(None, None), + program=None, reset=True, print_accuracy=False, print_loss=False): + """ Train model. + + :param X: training sample data (sfix tensor) + :param Y: training labels (sint/sfix tensor) + :param epochs: number of epochs (int) + :param batch_size: batch size (int) + :param validation_data: tuple of test sample data and labels for + accuracy testing (optional; reveals labels) + :param program: :py:class:`~Compile.program.Program` instance to use + command-line parameters (optional) + :param reset: whether to initialize model + :param print_accuracy: print accuracy on training data (reveals labels) + :param print_loss: reveal and print training loss after every batch + + """ + self.layers[0].X = X + self.layers[-1].Y = Y + self.revealing_correctness = print_accuracy + self.print_losses = print_loss + self.time_training = False + self.run_by_args(program, epochs, batch_size, *validation_data, + reset=reset) + def output_weights(self): print_float_precision(max(6, sfix.f // 3)) for layer in self.layers: layer.output_weights() + def summary(self): + sizes = [var.total_size() for var in self.thetas] + print(sizes) + print('Trainable params:', sum(sizes)) + + @property + def trainable_variables(self): + return list(self.thetas) + + def reveal_model_to_binary(self): + input_shape = self.layers[0].X.shape + for layer in self.layers: + if len(input_shape) == 4 and isinstance(layer, DenseBase): + layer.reveal_parameters_to_binary(reshape=input_shape[1:]) + else: + layer.reveal_parameters_to_binary() + input_shape = layer.Y.shape + class Adam(Optimizer): """ Adam/AMSgrad optimizer. @@ -2130,7 +2713,8 @@ class Adam(Optimizer): """ def __init__(self, layers, n_epochs=1, approx=False, amsgrad=False, normalize=False): - self.gamma = MemValue(cfix(.001)) + super(Adam, self).__init__() + self.set_learning_rate(.001) self.beta1 = 0.9 self.beta2 = 0.999 self.beta1_power = MemValue(cfix(1)) @@ -2141,15 +2725,15 @@ def __init__(self, layers, n_epochs=1, approx=False, amsgrad=False, self.amsgrad = amsgrad self.normalize = normalize if amsgrad: - print_str('Using AMSgrad ') + print_both('Using AMSgrad ', end='') else: - print_str('Using Adam ') + print_both('Using Adam ', end='') if approx: - print_ln('with inverse square root approximation') + print_both('with inverse square root approximation') else: - print_ln('with more precise inverse square root') + print_both('with more precise inverse square root') if normalize: - print_ln('Normalize gradient') + print_both('Normalize gradient') self.layers = layers self.ms = [] @@ -2164,9 +2748,7 @@ def __init__(self, layers, n_epochs=1, approx=False, amsgrad=False, if amsgrad: self.vhats.append(nabla.same_shape()) - super(Adam, self).__init__() - - def update(self, i_epoch, batch): + def _update(self, i_epoch, i_batch, batch): self.beta1_power *= self.beta1 self.beta2_power *= self.beta2 m_factor = MemValue(1 / (1 - self.beta1_power)) @@ -2182,7 +2764,8 @@ def _(base, size): util.max, abs_g.get_vector()) scale = MemValue(sfix._new(library.AppRcr( max_g.v, max_g.k, max_g.f, simplex_flag=True))) - @multithread(self.n_threads, m.total_size()) + @multithread(self.n_threads, m.total_size(), + max_size=get_program().budget) def _(base, size): m_part = m.get_vector(base, size) v_part = v.get_vector(base, size) @@ -2193,20 +2776,30 @@ def _(base, size): v_part = self.beta2 * v_part + (1 - self.beta2) * g_part ** 2 m.assign_vector(m_part, base) v.assign_vector(v_part, base) + mhat = m_part * m_factor.expand_to_vector(size) + vhat = v_part * v_factor.expand_to_vector(size) if self.amsgrad: - vhat = self.vhats [i_layer].get_vector(base, size) - vhat = util.max(vhat, v_part) + v_max = self.vhats [i_layer].get_vector(base, size) + vhat = util.max(vhat, v_max) self.vhats[i_layer].assign_vector(vhat, base) - diff = self.gamma.expand_to_vector(size) * m_part - else: - mhat = m_part * m_factor.expand_to_vector(size) - vhat = v_part * v_factor.expand_to_vector(size) - diff = self.gamma.expand_to_vector(size) * mhat + diff = self.gamma.expand_to_vector(size) * mhat if self.approx: diff *= mpc_math.InvertSqrt(vhat + self.epsilon ** 2) else: diff /= mpc_math.sqrt(vhat) + self.epsilon theta.assign_vector(theta.get_vector(base, size) - diff, base) + if self.output_diff: + @if_(i_batch % 100 == 0) + def _(): + diff.reveal().binary_output() + if self.output_stats and m.total_size() < 1000: + @if_(i_batch % self.output_stats == 0) + def _(): + self.stat('g', g) + self.stat('m', m) + self.stat('v', v) + self.stat('vhat', self.vhats[i_layer]) + self.stat('theta', theta) class SGD(Optimizer): """ Stochastic gradient descent. @@ -2215,7 +2808,8 @@ class SGD(Optimizer): :param n_epochs: number of epochs for training :param report_loss: disclose and print loss """ - def __init__(self, layers, n_epochs, debug=False, report_loss=None): + def __init__(self, layers, n_epochs=1, debug=False, report_loss=None): + super(SGD, self).__init__(report_loss=report_loss) self.momentum = 0.9 self.layers = layers self.n_epochs = n_epochs @@ -2225,9 +2819,9 @@ def __init__(self, layers, n_epochs, debug=False, report_loss=None): self.nablas.extend(layer.nablas()) for theta in layer.thetas(): self.delta_thetas.append(theta.same_shape()) - self.gamma = MemValue(cfix(0.01)) + self.set_learning_rate(0.01) self.debug = debug - super(SGD, self).__init__(report_loss) + print_both('Using SGD') @_no_mem_warnings def reset(self, X_by_label=None): @@ -2247,7 +2841,7 @@ def _(i): y.assign_all(0) super(SGD, self).reset() - def update(self, i_epoch, batch): + def _update(self, i_epoch, i_batch, batch): for nabla, theta, delta_theta in zip(self.nablas, self.thetas, self.delta_thetas): @multithread(self.n_threads, nabla.total_size()) @@ -2304,20 +2898,35 @@ def _(i): print_ln_if((x > limit) + (x < -limit), 'theta epoch=%s %s index=%s %s', i_epoch.read(), str(theta), i, x) - index = regint.get_random(64) % len(a) - print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, - aa[1][index], aa[0][index], aa[2][index]) + if self.print_random_update: + print_ln('update') + l = min(100, nabla.total_size()) + if l < 100: + index = 0 + else: + index = regint.get_random(64) % (nabla.total_size() - l) + print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), + index, nabla.to_array().get_vector(index, l).reveal(), + delta_theta.to_array().get_vector(index, l).reveal(), + theta.to_array().get_vector(index, l).reveal()) self.gamma.imul(1 - 10 ** - 6) def apply_padding(input_shape, kernel_size, strides, padding): - if padding == 'valid': - return (input_shape[0] - kernel_size[0] + 1) // strides[0], \ + if isinstance(padding, int): + padding = [padding, padding] + if isinstance(padding, (tuple, list)): + input_shape = [x + sum(padding) for x in input_shape] + padding = 'valid' + if padding.lower() == 'valid': + res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \ (input_shape[1] - kernel_size[1] + 1) // strides[1], - elif padding == 'same': - return (input_shape[1]) // strides[0], \ - (input_shape[2]) // strides[1], + assert min(res) > 0, (input_shape, kernel_size, strides, padding) + return res + elif padding.lower() == 'same': + return (input_shape[0]) // strides[0], \ + (input_shape[1]) // strides[1], else: - raise Exception('invalid padding: ' + padding) + raise Exception('invalid padding: %s' % padding) class keras: class layers: @@ -2325,7 +2934,7 @@ class layers: Dense = lambda *args, **kwargs: ('dense', args, kwargs) def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', - activation=None): + activation=None, input_shape=None): return 'conv2d', {'filters': filters, 'kernel_size': kernel_size, 'strides': strides, 'padding': padding, 'activation': activation} @@ -2340,6 +2949,13 @@ def Dropout(rate): raise Exception('rate needs to be a power of two') return 'dropout', rate + def Activation(activation): + assert(activation == 'relu') + return activation, + + def BatchNormalization(): + return 'batchnorm', + class optimizers: SGD = lambda *args, **kwargs: ('sgd', args, kwargs) Adam = lambda *args, **kwargs: ('adam', args, kwargs) @@ -2354,15 +2970,30 @@ def __init__(self, layers): def compile(self, optimizer): self.optimizer = optimizer + def compile_by_args(self, program): + if 'adam' in program.args: + self.optimizer = 'adam', [], {} + elif 'amsgrad' in program.args: + self.optimizer = 'adam', [], {'amsgrad': True} + elif 'amsgradprec' in program.args: + self.optimizer = 'adam', [], {'amsgrad': True, + 'approx': False} + else: + self.optimizer = 'sgd', [], {} + @property def trainable_variables(self): if self.opt == None: raise Exception('need to run build() or fit() first') return list(self.opt.thetas) + def summary(self): + self.opt.summary() + def build(self, input_shape, batch_size=128): + data_input_shape = input_shape if self.opt != None and \ - input_shape == self.opt.layers[0].X.sizes and \ + input_shape == self.opt.layers[0]._X.sizes and \ batch_size <= self.batch_size and \ type(self.opt).__name__.lower() == self.optimizer[0]: return @@ -2385,12 +3016,11 @@ def build(self, input_shape, batch_size=128): if i == len(self.layers) - 1: if layer[2].get('activation', 'softmax') in \ ('softmax', 'sigmoid'): - del layer[2]['activation'] + layer[2].pop('activation', None) layers.append(Dense(N, n_units, layer[1][0], **layer[2])) + input_shape = layers[-1].Y.sizes elif name == 'conv2d': - if len(layers) != 0: - input_shape = layers[-1].Y.sizes input_shape = list(input_shape) + \ [1] * (4 - len(input_shape)) print (layer[1]) @@ -2398,44 +3028,38 @@ def build(self, input_shape, batch_size=128): filters = layer[1]['filters'] strides = layer[1]['strides'] padding = layer[1]['padding'] - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - if isinstance(strides, int): - strides = (strides, strides) - weight_shape = [filters] + list(kernel_size) + \ - [input_shape[-1]] - output_shape = [batch_size] + list( - apply_padding(input_shape[1:3], kernel_size, - strides, padding)) + [filters] - layers.append(FixConv2d(input_shape, weight_shape, - (filters,), output_shape, - strides, padding.upper())) + layers.append(easyConv2d( + input_shape, batch_size, filters, kernel_size, + strides, padding)) + output_shape = layers[-1].Y.sizes + input_shape = output_shape + print('conv output shape', output_shape) elif name == 'maxpool': pool_size = layer[1]['pool_size'] strides = layer[1]['strides'] padding = layer[1]['padding'] - if isinstance(pool_size, int): - pool_size = (pool_size, pool_size) - if isinstance(strides, int): - strides = (strides, strides) - if strides == None: - strides = pool_size - layers.append(MaxPool(layers[-1].Y.sizes, - [1] + list(strides) + [1], - [1] + list(pool_size) + [1], - padding.upper())) + layers.append(easyMaxPool(input_shape, pool_size, + strides, padding)) + input_shape = layers[-1].Y.sizes elif name == 'dropout': layers.append(Dropout(batch_size, reduce( operator.mul, layers[-1].Y.sizes[1:]), alpha=layer[1])) + input_shape = layers[-1].Y.sizes elif name == 'flatten': pass + elif name == 'relu': + layers.append(Relu(layers[-1].Y.sizes)) + elif name == 'batchnorm': + input_shape = layers[-1].Y.sizes + layers.append(BatchNorm(layers[-1].Y.sizes)) else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: - layers.append(Output(input_shape[0])) + layers.append(Output(data_input_shape[0])) else: - layers.append(MultiOutput(input_shape[0], layers[-1].d_out)) + layers.append( + MultiOutput(data_input_shape[0], layers[-1].d_out)) if self.optimizer[1]: raise Exception('use keyword arguments for optimizer') opt = self.optimizer[0] @@ -2447,7 +3071,7 @@ def build(self, input_shape, batch_size=128): opt.momentum = momentum elif opt == 'adam': opt = Adam(layers, amsgrad=opts.pop('amsgrad', None), - approx=True) + approx=opts.pop('approx', True)) beta1 = opts.pop('beta_1', None) beta2 = opts.pop('beta_2', None) epsilon = opts.pop('epsilon', None) @@ -2467,7 +3091,7 @@ def build(self, input_shape, batch_size=128): raise Exception(opt + ' not supported') lr = opts.pop('learning_rate', None) if lr != None: - opt.gamma = MemValue(cfix(lr)) + opt.set_learning_rate(lr) if opts: raise Exception(opts + ' not supported') self.batch_size = batch_size @@ -2476,17 +3100,17 @@ def build(self, input_shape, batch_size=128): def fit(self, x, y, batch_size, epochs=1, validation_data=None): assert len(x) == len(y) self.build(x.sizes, batch_size) - if x.total_size() != self.opt.layers[0].X.total_size(): + if x.total_size() != self.opt.layers[0]._X.total_size(): raise Exception('sample data size mismatch') if y.total_size() != self.opt.layers[-1].Y.total_size(): - print (y, layers[-1].Y) + print (y, self.opt.layers[-1].Y) raise Exception('label size mismatch') if validation_data == None: validation_data = None, None else: if len(validation_data[0]) != len(validation_data[1]): raise Exception('test set size mismatch') - self.opt.layers[0].X.address = x.address + self.opt.layers[0]._X.address = x.address self.opt.layers[-1].Y.address = y.address self.opt.run_by_args(get_program(), epochs, batch_size, validation_data[0], validation_data[1], @@ -2500,16 +3124,219 @@ def predict(self, x, batch_size=None): batch_size = min(batch_size, self.batch_size) return self.opt.eval(x, batch_size=batch_size) -def solve_linear(A, b, n_iterations, progress=False): - """ Iterative linear solution approximation. """ +def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None): + """ Convert a PyTorch Sequential object to MP-SPDZ layers. + + :param sequence: PyTorch Sequential object + :param data_input_shape: input shape (list of four int) + :param batch_size: batch size (int) + :param input_via: player to input model data via (default: don't) + + """ + layers = [] + + def mul(x): + return reduce(operator.mul, x) + + def process(item): + nonlocal input_shape + name = type(item).__name__ + if name == 'Sequential': + for x in item: + process(x) + elif name == 'Linear': + assert mul(input_shape[1:]) == item.in_features + assert item.bias is not None + layers.append(Dense(input_shape[0], item.in_features, + item.out_features)) + if input_via is not None: + shapes = [x.shape for x in (layers[-1].W, layers[-1].b)] + import numpy + swapped = item.weight.detach().numpy() + if len(input_shape) == 4: + print (swapped.shape) + swapped = numpy.reshape( + swapped, + [item.out_features, input_shape[3]] + input_shape[1:3]) + print (swapped.shape) + swapped = numpy.moveaxis(swapped, 1, -1) + print (swapped.shape) + swapped = numpy.reshape( + swapped, [item.out_features, item.in_features]) + print (swapped.shape) + swapped = numpy.swapaxes(swapped, 0, 1) + layers[-1].W = sfix.input_tensor_via( + input_via, swapped) + layers[-1].b = sfix.input_tensor_via( + input_via, item.bias.detach()) + assert layers[-1].W.shape == shapes[0] + assert layers[-1].b.shape == shapes[1] + input_shape = [batch_size, item.out_features] + elif name == 'Conv2d': + layers.append(easyConv2d(input_shape, batch_size, item.out_channels, + item.kernel_size, item.stride, + item.padding)) + input_shape = layers[-1].Y.shape + if input_via is not None: + shapes = [x.shape for x in + (layers[-1].weights, layers[-1].bias)] + import numpy + swapped = numpy.moveaxis( + numpy.array(item.weight.detach()), 1, -1) + layers[-1].weights = sfix.input_tensor_via(input_via, swapped) + layers[-1].bias = sfix.input_tensor_via( + input_via, item.bias.detach()) + assert layers[-1].weights.shape == shapes[0] + assert layers[-1].bias.shape == shapes[1] + elif name == 'MaxPool2d': + layers.append(easyMaxPool(input_shape, item.kernel_size, + item.stride, item.padding)) + input_shape = layers[-1].Y.shape + elif name == 'ReLU': + layers.append(Relu(input_shape)) + elif name == 'Flatten': + pass + elif name == 'BatchNorm2d': + layers.append(BatchNorm(layers[-1].Y.sizes)) + elif name == 'Dropout': + layers.append(Dropout(input_shape[0], mul(layers[-1].Y.sizes[1:]), + alpha=item.p)) + input_shape = layers[-1].Y.sizes + else: + raise CompilerError('unknown PyTorch module: ' + name) + + input_shape = data_input_shape + [1] * (4 - len(data_input_shape)) + process(sequence) + if layers[-1].d_out == 1: + layers.append(Output(data_input_shape[0])) + else: + layers.append(MultiOutput(data_input_shape[0], layers[-1].d_out)) + return layers + +class OneLayerSGD: + def __init__(self, n_epochs=1, batch_size=1, program=None): + self.n_epochs = n_epochs + self.batch_size = batch_size + self.program = program + + def fit(self, X_train, y_train): + """ Train classifier. + + :param X_train: training data (sfix matrix) + :param y_train: training binary labels (sint/sfix array) + + """ + self.init(X_train) + self.opt.fit(X_train, y_train, self.n_epochs, self.batch_size, + program=self.program, print_accuracy=False, + print_loss=False) + + def fit_with_testing(self, X_train, y_train, X_test, y_test): + """ Train classifier with accuracy output after every epoch. + This reveals all labels to simplify the accuracy computation. + + :param X_train: training data (sfix matrix) + :param y_train: training labels (sint/sfix array) + :param X_test: testing data (sfix matrix) + :param y_test: testing labels (sint/sfix array) + + """ + self.init(X_train) + self.opt.print_accuracy = self.print_accuracy + self.opt.fit(X_train, y_train, self.n_epochs, self.batch_size, + validation_data=(X_test, y_test), program=self.program, + print_accuracy=self.print_accuracy, print_loss=True) + + def predict(self, X): + """ Use model for prediction. + + :param X: sample data with row-wise samples (sfix matrix) + :returns: sfix array + + """ + return self.opt.eval(X) + +class SGDLogistic(OneLayerSGD): + """ Logistic regression using SGD. + + :param n_epochs: number of epochs + :param batch_size: batch size + :param program: program object to use command-line options from (default is + not to use any) + + """ + print_accuracy = True + + def init(self, X): + dense = Dense(*X.sizes, 1) + if self.program: + sigmoid = Output.from_args(X.sizes[0], self.program) + self.opt = Optimizer.from_args(self.program, [dense, sigmoid]) + else: + sigmoid = Output(X.sizes[0]) + self.opt = SGD([dense, sigmoid], 1) + + def predict(self, X): + """ Use model to predict labels. + + :param X: sample data with row-wise samples (sfix matrix) + :returns: sint array + + """ + return self.opt.eval(X, top=True) + + def predict_proba(self, X): + """ Use model for probility estimates. + + :param X: sample data with row-wise samples (sfix matrix) + :returns: sfix array + + """ + return super(SGDLogistic, self).predict(X) + +class SGDLinear(OneLayerSGD): + """ Logistic regression using SGD. + + :param n_epochs: number of epochs + :param batch_size: batch size + :param program: program object to use command-line options from (default is + not to use any) + + """ + print_accuracy = False + + def init(self, X): + dense = Dense(*X.sizes, 1) + output = LinearOutput(X.sizes[0]) + if self.program: + self.opt = Optimizer.from_args(self.program, [dense, output]) + else: + self.opt = SGD([dense, output], 1) + +def solve_linear(A, b, n_iterations, progress=False, n_threads=None, + stop=False, already_symmetric=False, precond=False): + """ Iterative linear solution approximation for :math:`Ax=b`. + + :param progress: print some information on the progress (implies revealing) + :param n_threads: number of threads to use + :param stop: whether to stop when converged (implies revealing) + + """ assert len(b) == A.sizes[0] x = sfix.Array(A.sizes[1]) x.assign_vector(sfix.get_random(-1, 1, size=len(x))) - AtA = sfix.Matrix(len(x), len(x)) - AtA[:] = A.direct_trans_mul(A) + if already_symmetric: + AtA = A + r = Array.create_from(b - AtA * x) + else: + AtA = sfix.Matrix(len(x), len(x)) + A.trans_mul_to(A, AtA, n_threads=n_threads) + r = Array.create_from(A.transpose() * b - AtA * x) + if precond: + return solve_linear_diag_precond(AtA, b, x, r, n_iterations, + progress, stop) v = sfix.Array(A.sizes[1]) v.assign_all(0) - r = Array.create_from(A.transpose() * b - AtA * x) Av = sfix.Array(len(x)) @for_range(n_iterations) def _(i): @@ -2523,10 +3350,45 @@ def _(i): if progress: print_ln('%s alpha=%s vr=%s v_norm=%s', i, alpha.reveal(), vr.reveal(), v_norm.reveal()) + if stop: + return (alpha > 0).reveal() + if not already_symmetric: + AtA.delete() return x -def mr(A, n_iterations): - """ Iterative matrix inverse approximation. """ +def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False, + stop=False): + m = 1 / A.diag() + mr = Array.create_from(m * r[:]) + d = Array.create_from(mr) + @for_range(n_iterations) + def _(i): + Ad = A * d + d_norm = sfix.dot_product(d, Ad) + alpha = (d_norm == 0).if_else(0, sfix.dot_product(r, mr) / d_norm) + x[:] = x[:] + alpha * d[:] + r_norm = sfix.dot_product(r, mr) + r[:] = r[:] - alpha * Ad + tmp = m * r[:] + beta = (r_norm == 0).if_else(0, sfix.dot_product(r, tmp) / r_norm) + mr[:] = tmp + d[:] = tmp + beta * d + if progress: + print_ln('%s alpha=%s beta=%s r_norm=%s d_norm=%s', i, + alpha.reveal(), beta.reveal(), r_norm.reveal(), + d_norm.reveal()) + if stop: + return (alpha > 0).reveal() + return x + +def mr(A, n_iterations, stop=False): + """ Iterative matrix inverse approximation. + + :param A: matrix to invert + :param n_iterations: maximum number of iterations + :param stop: whether to stop when converged (implies revealing) + + """ assert len(A.sizes) == 2 assert A.sizes[0] == A.sizes[1] M = A.same_shape() @@ -2536,5 +3398,41 @@ def _(i): e = sfix.Array(n) e.assign_all(0) e[i] = 1 - M[i] = solve_linear(A, e, n_iterations) + M[i] = solve_linear(A, e, n_iterations, stop=stop) return M.transpose() + +def var(x): + """ Variance. """ + mean = MemValue(type(x[0])(0)) + @for_range_opt(len(x)) + def _(i): + mean.iadd(x[i]) + mean /= len(x) + res = MemValue(type(x[0])(0)) + @for_range_opt(len(x)) + def _(i): + res.iadd((x[i] - mean.read()) ** 2) + return res.read() + +def cholesky(A, reveal_diagonal=False): + """ Cholesky decomposition. """ + assert len(A.shape) == 2 + assert A.shape[0] == A.shape[1] + L = A.same_shape() + L.assign_all(0) + @for_range(A.shape[0]) + def _(i): + @for_range(i + 1) + def _(j): + sum = sfix.dot_product(L[i], L[j]) + + @if_e(i == j) + def _(): + L[i][j] = mpc_math.sqrt(A[i][i] - sum) + if reveal_diagonal: + print_ln('L[%s][%s] = %s = sqrt(%s - %s)', i, j, + L[i][j].reveal(), A[i][j].reveal(), sum.reveal()) + @else_ + def _(): + L[i][j] = (1.0 / L[j][j] * (A[i][j] - sum)) + return L diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 322989b34..56d09e1ee 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -8,6 +8,8 @@ import math +import operator +from functools import reduce from Compiler import floatingpoint from Compiler import types from Compiler import comparison @@ -290,12 +292,11 @@ class my_fix(type(a)): # how many bits to use from integer part n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) n_bits = a.f + n_int_bits - sint = types.sint + sint = a.int_type if types.program.options.ring and not as19: intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k if types.program.use_split(): - assert not zero_output from Compiler.GC.types import sbitvec if types.program.use_split() == 3: x = a.v.split_to_two_summands(a.k) @@ -327,6 +328,7 @@ class my_fix(type(a)): s = sint.conv(bits[-1]) lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) higher_bits = bits[a.f:n_bits] + bits_to_check = bits[n_bits:-1] else: if types.program.use_edabit(): l = sint.get_edabit(a.f, True) @@ -338,7 +340,7 @@ class my_fix(type(a)): r_bits = [sint.get_random_bit() for i in range(a.k)] r = sint.bit_compose(r_bits) lower_r = sint.bit_compose(r_bits[:a.f]) - shifted = ((a.v - r) << n_shift).reveal() + shifted = ((a.v - r) << n_shift).reveal(False) masked_bits = (shifted >> n_shift).bit_decompose(a.k) lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], r_bits[a.f-1::-1]) @@ -367,17 +369,17 @@ class my_fix(type(a)): bits = a.v.bit_decompose(a.k, maybe_mixed=True) lower = sint.bit_compose(bits[:a.f]) higher_bits = bits[a.f:n_bits] - s = sint.conv(bits[-1]) + s = a.bit_type.conv(bits[-1]) bits_to_check = bits[n_bits:-1] if not as19: - c = types.sfix._new(lower, k=a.k, f=a.f) + c = a._new(lower, k=a.k, f=a.f) assert(len(higher_bits) == n_bits - a.f) pow2_bits = [sint.conv(x) for x in higher_bits] d = floatingpoint.Pow2_from_bits(pow2_bits) g = exp_from_parts(d, c) - small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, + small_result = a._new(g.v.round(a.f + 2 ** n_int_bits, 2 ** n_int_bits, signed=False, - nearest=types.sfix.round_nearest), + nearest=a.round_nearest), k=a.k, f=a.f) if zero_output: t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), @@ -398,6 +400,36 @@ class my_fix(type(a)): return s.if_else(1 / g, g) +def mux_exp(x, y, block_size=8): + assert util.is_constant_float(x) + from Compiler.GC.types import sbitvec, sbits + bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v + sign = bits[-1] + m = math.log(2 ** (y.k - y.f - 1), x) + del bits[int(math.ceil(math.log(m, 2))) + y.f:] + parts = [] + for i in range(0, len(bits), block_size): + one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v + exp = [] + try: + for j in range(len(one_hot)): + exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f)) + except OverflowError: + pass + exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp)) + bin_part = [0] * max(x.bit_length() for x in exp) + for j in range(len(bin_part)): + for k, (a, b) in enumerate(zip(one_hot, exp)): + bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \ + else 0 + if util.is_zero(bin_part[j]): + bin_part[j] = sbits.get_type(y.size)(0) + if i == 0: + bin_part[j] = sign.if_else(0, bin_part[j]) + parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part)))) + return util.tree_reduce(operator.mul, parts) + + @types.vectorize @instructions_base.sfix_cisc def log2_fx(x, use_division=True): @@ -420,6 +452,8 @@ def log2_fx(x, use_division=True): p -= x.f vlen = x.f v = x._new(v, k=x.k, f=x.f) + elif isinstance(x, (types._register, types.cfix)): + return log2_fx(types.sfix(x), use_division) else: d = types.sfloat(x) v, p, vlen = d.v, d.p, d.vlen @@ -501,7 +535,7 @@ def abs_fx(x): # # @return floored sint value of x def floor_fx(x): - return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x)) + return load_sint(floatingpoint.Trunc(x.v, x.k, x.f, x.kappa), type(x)) ### sqrt methods @@ -882,7 +916,7 @@ def SqrtComp(z, old=False): k = len(z) if isinstance(z[0], types.sint): return types.sfix._new(sum(z[i] * types.cfix( - 2 ** (-(i - f + 1) / 2)).v for i in range(k))) + 2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k))) k_prime = k // 2 f_prime = f // 2 c1 = types.sfix(2 ** ((f + 1) / 2 + 1)) diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 43e10c2e6..66e82908d 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -1,7 +1,7 @@ from .comparison import * from .floatingpoint import * from .types import * -from . import comparison +from . import comparison, program class NonLinear: kappa = None @@ -30,6 +30,17 @@ def mod2m(self, a, k, m, signed): def trunc_pr(self, a, k, m, signed=True): if isinstance(a, types.cint): return shift_two(a, m) + prog = program.Program.prog + if prog.use_trunc_pr: + if not prog.options.ring: + prog.curr_tape.require_bit_length(k + prog.security) + if signed and prog.use_trunc_pr != -1: + a += (1 << (k - 1)) + res = sint() + trunc_pr(res, a, k, m) + if signed and prog.use_trunc_pr != -1: + res -= (1 << (k - m - 1)) + return res return self._trunc_pr(a, k, m, signed) def trunc_round_nearest(self, a, k, m, signed): @@ -44,6 +55,9 @@ def trunc(self, a, k, m, kappa, signed): return a return self._trunc(a, k, m, signed) + def ltz(self, a, k, kappa=None): + return -self.trunc(a, k, k - 1, kappa, True) + class Masking(NonLinear): def eqz(self, a, k): c, r = self._mask(a, k) @@ -100,42 +114,44 @@ def __init__(self, prime): def _mod2m(self, a, k, m, signed): if signed: a += cint(1) << (k - 1) - return sint.bit_compose(self.bit_dec(a, k, k, True)[:m]) + return sint.bit_compose(self.bit_dec(a, k, m, True)) def _trunc_pr(self, a, k, m, signed): # nearest truncation return self.trunc_round_nearest(a, k, m, signed) def _trunc(self, a, k, m, signed=None): - if signed: - a += cint(1) << (k - 1) - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) - if signed: - res -= cint(1) << (k - 1 - m) - return res + return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed) def trunc_round_nearest(self, a, k, m, signed): a += cint(1) << (m - 1) if signed: a += cint(1) << (k - 1) k += 1 - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) + res = self._trunc(a, k, m, False) if signed: res -= cint(1) << (k - m - 2) return res def bit_dec(self, a, k, m, maybe_mixed=False): assert k < self.prime.bit_length() - bits = BitDecFull(a, maybe_mixed=maybe_mixed) - if len(bits) < m: - raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) - return bits[:m] + bits = BitDecFull(a, m, maybe_mixed=maybe_mixed) + assert len(bits) == m + return bits def eqz(self, a, k): # always signed a += two_power(k) return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) + def ltz(self, a, k, kappa=None): + if k + 1 < self.prime.bit_length(): + # https://dl.acm.org/doi/10.1145/3474123.3486757 + # "negative" values wrap around when doubling, thus becoming odd + return self.mod2m(2 * a, k + 1, 1, False) + else: + return super(KnownPrime, self).ltz(a, k, kappa) + class Ring(Masking): """ Non-linear functionality modulo a power of two known at compile time. """ @@ -172,3 +188,6 @@ def trunc_round_nearest(self, a, k, m, signed): return TruncRing(None, tmp + 1, k - m + 1, 1, signed) else: return super(Ring, self).trunc_round_nearest(a, k, m, signed) + + def ltz(self, a, k, kappa=None): + return LtzRing(a, k) diff --git a/Compiler/oram.py b/Compiler/oram.py index 443d826cd..bbaa3938c 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -348,7 +348,7 @@ def __iter__(self): def __len__(self): return 2 + len(self.x) def __repr__(self): - return '{empty=%s}' % self.is_empty if self.is_empty \ + return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \ else '{%s: %s}' % (self.v, self.x) def __add__(self, other): try: @@ -466,12 +466,14 @@ class AbstractORAM(object): def get_array(size, t, *args, **kwargs): return t.dynamic_array(size, t, *args, **kwargs) def read(self, index): - return self._read(self.value_type.hard_conv(index)) + res = self._read(self.index_type.hard_conv(index)) + res = [self.value_type._new(x) for x in res] + return res def write(self, index, value): + value = util.tuplify(value) + value = [self.value_type.conv(x) for x in value] new_value = [self.value_type.get_type(length).hard_conv(v) \ - for length,v in zip(self.entry_size, value \ - if isinstance(value, (tuple, list)) \ - else (value,))] + for length,v in zip(self.entry_size, value)] return self._write(self.index_type.hard_conv(index), *new_value) def access(self, index, new_value, write, new_empty=False): return self._access(self.index_type.hard_conv(index), @@ -795,18 +797,19 @@ def batch_init(self, values): for i,value in enumerate(values): index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.ram[i] = Entry(index, new_value, value_type=self.value_type) class TrivialORAM(RefTrivialORAM, AbstractORAM): """ Trivial ORAM (obviously). """ ref_type = RefTrivialORAM - def __init__(self, size, value_type=sint, value_length=1, index_size=None, \ + def __init__(self, size, value_type=None, value_length=1, index_size=None, \ entry_size=None, contiguous=True, init_rounds=-1): self.index_size = index_size or log2(size) - self.value_type = value_type - self.index_type = value_type.get_type(self.index_size) + self.value_type = value_type or sint + self.index_type = self.value_type.get_type(self.index_size) if entry_size is None: self.value_length = value_length self.entry_size = [None] * value_length @@ -859,15 +862,16 @@ def _read(self, index): empty_entry = self.empty_entry(False) demux_array(bit_decompose(index, self.index_size), \ self.index_vector) + t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size)) @map_sum(get_n_threads(self.size), n_parallel, self.size, \ - self.value_length + 1, [self.value_type.bit_type] + \ - [self.value_type.get_type(l) for l in self.entry_size]) + self.value_length + 1, t) def f(i): entry = self.ram[i] access_here = self.index_vector[i] return access_here * ValueTuple((entry.empty(),) + entry.x) - not_found = f()[0] - read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x + not_found = self.value_type.bit_type(f()[0]) + read_value = ValueTuple(self.value_type.get_type(l)(x) for l, x in zip(self.entry_size, f()[1:])) + \ + not_found * empty_entry.x maybe_stop_timer(6) return read_value, not_found @method_block @@ -876,7 +880,9 @@ def _write(self, index, *new_value): empty_entry = self.empty_entry(False) demux_array(bit_decompose(index, self.index_size), \ self.index_vector) - new_value = make_array(new_value) + new_value = make_array( + new_value, self.value_type.get_type( + max(x or 0 for x in self.entry_size))) @for_range_multithread(get_n_threads(self.size), n_parallel, self.size) def f(i): entry = self.ram[i] @@ -892,7 +898,9 @@ def _access(self, index, write, new_empty, *new_value): empty_entry = self.empty_entry(False) index_vector = \ demux_array(bit_decompose(index, self.index_size)) - new_value = make_array(new_value) + new_value = make_array( + new_value, self.value_type.get_type( + max(x or 0 for x in self.entry_size))) new_empty = MemValue(new_empty) write = MemValue(write) @map_sum(get_n_threads(self.size), n_parallel, self.size, \ @@ -986,7 +994,8 @@ def batch_init(self, values): for i,value in enumerate(values): index = self.value_type.hard_conv(i) new_value = [self.value_type.hard_conv(v) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.__setitem__(index, new_value) def __repr__(self): @@ -1025,8 +1034,9 @@ def get_n_threads_for_tree(size): class TreeORAM(AbstractORAM): """ Tree ORAM. """ - def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ + def __init__(self, size, value_type=None, value_length=1, entry_size=None, \ bucket_oram=TrivialORAM, init_rounds=-1): + value_type = value_type or sint print('create oram of size', size) self.bucket_oram = bucket_oram # heuristic bucket size @@ -1062,11 +1072,12 @@ def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ stop_timer(1) start_timer() self.root = RefBucket(1, self) - self.index = self.index_structure(size, self.D, value_type, init_rounds, True) + self.index = self.index_structure(size, self.D, self.index_type, + init_rounds, True) - self.read_value = Array(self.value_length, value_type) + self.read_value = Array(self.value_length, value_type.default_type) self.read_non_empty = MemValue(self.value_type.bit_type(0)) - self.state = MemValue(self.value_type(0)) + self.state = MemValue(self.value_type.default_type(0)) @method_block def add_to_root(self, state, is_empty, v, *x): if len(x) != self.value_length: @@ -1106,10 +1117,10 @@ def evict2(self, p_bucket1, p_bucket2, d): self.evict_bucket(RefBucket(p_bucket2, self), d) @method_block def read_and_renew_index(self, u): - l_star = random_block(self.D, self.value_type) + l_star = random_block(self.D, self.index_type) if use_insecure_randomness: new_path = regint.get_random(self.D) - l_star = self.value_type(new_path) + l_star = self.index_type(new_path) self.state.write(l_star) return self.index.update(u, l_star, evict=False).reveal() @method_block @@ -1120,7 +1131,7 @@ def read_and_remove_levels(self, u, read_path): parallel = get_parallel(self.index_size, *self.internal_value_type()) @map_sum(get_n_threads_for_tree(self.size), parallel, levels, \ self.value_length + 1, [self.value_type.bit_type] + \ - [self.value_type] * self.value_length) + [self.value_type.default_type] * self.value_length) def process(level): b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level)) bucket = RefBucket(b_index, self) @@ -1142,9 +1153,9 @@ def f(): Program.prog.curr_tape.start_new_basicblock() crash() def internal_value_type(self): - return self.value_type, self.value_length + 1 + return self.value_type.default_type, self.value_length + 1 def internal_entry_size(self): - return self.value_type, [self.D] + list(self.entry_size) + return self.value_type.default_type, [self.D] + list(self.entry_size) def n_buckets(self): return 2**(self.D+1) @method_block @@ -1176,8 +1187,9 @@ def add(self, entry, state=None, evict=True): #print 'pre-add', self maybe_start_timer(4) self.add_to_root(state, entry.empty(), \ - self.value_type(entry.v.read()), \ - *(self.value_type(i.read()) for i in entry.x)) + self.index_type(entry.v.read()), \ + *(self.value_type.default_type(i.read()) + for i in entry.x)) maybe_stop_timer(4) #print 'pre-evict', self if evict: @@ -1221,28 +1233,35 @@ def batch_init(self, values): """ Batch initalization. Obliviously shuffles and adds N entries to random leaf buckets. """ m = len(values) - assert((m & (m-1)) == 0) + if not (m & (m-1)) == 0: + raise CompilerError('Batch size must a power of 2.') if m != self.size: raise CompilerError('Batch initialization must have N values.') if self.value_type != sint: raise CompilerError('Batch initialization only possible with sint.') depth = log2(m) - leaves = [0] * m - entries = [0] * m - indexed_values = [0] * m + leaves = self.value_type.Array(m) + indexed_values = \ + self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1) # assign indices 0, ..., m-1 - for i,value in enumerate(values): + @for_range(m) + def _(i): + value = values[i] index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ for v in (value if isinstance(value, (tuple, list)) \ else (value,))] indexed_values[i] = [index] + new_value - + entries = sint.Matrix(self.bucket_size * 2 ** self.D, + len(Entry(0, list(indexed_values[0]), False))) + # assign leaves - for i,index_value in enumerate(indexed_values): + @for_range(len(indexed_values)) + def _(i): + index_value = list(indexed_values[i]) leaves[i] = random_block(self.D, self.value_type) index = index_value[0] @@ -1252,18 +1271,20 @@ def batch_init(self, values): # save unsorted leaves for position map unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves] - permutation.sort(leaves, comp=permutation.normal_comparator) + leaves.sort() bucket_sz = 0 # B[i] = (pos, leaf, "last in bucket" flag) for i-th entry - B = [[0]*3 for i in range(m)] + B = sint.Matrix(m, 3) B[0] = [0, leaves[0], 0] B[-1] = [None, None, sint(1)] - s = 0 + s = MemValue(sint(0)) - for i in range(1, m): + @for_range_opt(m - 1) + def _(j): + i = j + 1 eq = leaves[i].equal(leaves[i-1]) - s = (s + eq) * eq + s.write((s + eq) * eq) B[i][0] = s B[i][1] = leaves[i] B[i-1][2] = 1 - eq @@ -1271,7 +1292,7 @@ def batch_init(self, values): #last_in_bucket[i-1] = 1 - eq # shuffle - permutation.shuffle(B, value_type=sint) + B.secure_shuffle() #cint(0).print_reg('shuf') sz = MemValue(0) #cint(0) @@ -1279,7 +1300,8 @@ def batch_init(self, values): empty_positions = Array(nleaves, self.value_type) empty_leaves = Array(nleaves, self.value_type) - for i in range(m): + @for_range(m) + def _(i): if_then(reveal(B[i][2])) #if B[i][2] == 1: #cint(i).print_reg('last') @@ -1291,12 +1313,13 @@ def batch_init(self, values): empty_positions[szval] = B[i][0] #pos[i][0] #empty_positions[szval].reveal().print_reg('ps0') empty_leaves[szval] = B[i][1] #pos[i][1] - sz += 1 + sz.iadd(1) end_if() - pos_bits = [] + pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2) - for i in range(nleaves): + @for_range_opt(nleaves) + def _(i): leaf = empty_leaves[i] # split into 2 if bucket size can't fit into one field elem if self.bucket_size + Program.prog.security > 128: @@ -1315,46 +1338,39 @@ def batch_init(self, values): bucket_bits = [b for sl in zip(bits2,bits) for b in sl] else: bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0] - pos_bits += [[b, leaf] for b in bucket_bits] + assert len(bucket_bits) == self.bucket_size + for j, b in enumerate(bucket_bits): + pos_bits[i * self.bucket_size + j] = [b, leaf] # sort to get empty positions first - permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator) + pos_bits.sort(n_bits=1) # now assign positions to empty entries - empty_entries = [0] * (self.bucket_size*2**self.D - m) - - for i in range(self.bucket_size*2**self.D - m): + @for_range(len(entries) - m) + def _(i): vtype, vlength = self.internal_value_type() leaf = vtype(pos_bits[i][1]) # set leaf in empty entry for assigning after shuffle - value = tuple([leaf] + [vtype(0) for j in range(vlength)]) + value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)]) entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype) - empty_entries[i] = entry + entries[m + i] = entry # now shuffle, reveal positions and place entries - entries = entries + empty_entries - while len(entries) & (len(entries)-1) != 0: - entries.append(None) - permutation.shuffle(entries, value_type=sint) - entries = [entry for entry in entries if entry is not None] - clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries] + entries.secure_shuffle() + clear_leaves = Array.create_from( + Entry(entries.get_columns()).x[0].reveal()) Program.prog.curr_tape.start_new_basicblock() bucket_sizes = Array(2**self.D, regint) for i in range(2**self.D): bucket_sizes[i] = 0 - k = 0 - for entry,leaf in zip(entries, clear_leaves): - leaf = leaf.read() - k += 1 - - # for some reason leaf_buckets is in bit-reversed order - bits = bit_decompose(leaf, self.D) - rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1])) - bucket = RefBucket(rev_leaf + (1 << self.D), self) - # hack: 1*entry ensures MemValues are converted to sints - bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry + + @for_range_opt(len(entries)) + def _(k): + leaf = clear_leaves[k] + bucket = RefBucket(leaf + (1 << self.D), self) + bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k]) bucket_sizes[leaf] += 1 self.index.batch_init([leaf.read() for leaf in unsorted_leaves]) @@ -1493,6 +1509,7 @@ def f(i): self.l[i] = [0] * self.elements_per_block time() print_ln('packed ORAM init %s/%s', i, real_init_rounds) + print_ln('packed ORAM init done') print('index initialized, size', size) def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple @@ -1598,16 +1615,20 @@ def __setitem__(self, index, value): def batch_init(self, values): """ Initialize m values with indices 0, ..., m-1 """ m = len(values) - n_entries = max(1, m/self.entries_per_block) - new_values = [0] * n_entries + n_entries = max(1, m//self.entries_per_block) + new_values = sint.Matrix(n_entries, self.elements_per_block) + values = Array.create_from(values) - for i in range(n_entries): + @for_range(n_entries) + def _(i): block = [0] * self.elements_per_block for j in range(self.elements_per_block): base = i * self.entries_per_block + j * self.entries_per_element for k in range(self.entries_per_element): - if base + k < m: - block[j] += values[base + k] << (k * self.entry_size) + @if_(base + k < m) + def _(): + block[j] += \ + values[base + k] << (k * sum(self.entry_size)) new_values[i] = block @@ -1661,13 +1682,51 @@ class OneLevelORAM(TreeORAM): pattern after one recursion. """ index_structure = BaseORAMIndexStructure +class BinaryORAM: + def __init__(self, size, value_type=None, **kwargs): + import circuit_oram + from Compiler.GC import types + n_bits = int(get_program().options.binary) + self.value_type = value_type or types.sbitintvec.get_type(n_bits) + self.index_type = self.value_type + oram_value_type = types.sbits.get_type(64) + if 'entry_size' not in kwargs: + kwargs['entry_size'] = n_bits + self.oram = circuit_oram.OptimalCircuitORAM( + size, value_type=oram_value_type, **kwargs) + self.size = size + def get_index(self, index): + return self.oram.value_type(self.index_type.conv(index).elements()[0]) + def __setitem__(self, index, value): + value = list(self.oram.value_type( + self.value_type.conv(v).elements()[0]) for v in tuplify(value)) + self.oram[self.get_index(index)] = value + def __getitem__(self, index): + value = self.oram[self.get_index(index)] + return untuplify(tuple(self.value_type(v) for v in tuplify(value))) + def read(self, index): + return self.oram.read(index) + def read_and_maybe_remove(self, index): + return self.oram.read_and_maybe_remove(index) + def access(self, *args): + return self.oram.access(*args) + def add(self, *args, **kwargs): + return self.oram.add(*args, **kwargs) + def delete(self, *args, **kwargs): + return self.oram.delete(*args, **kwargs) + def OptimalORAM(size,*args,**kwargs): """ Create an ORAM instance suitable for the size based on experiments. :param size: number of elements - :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` + :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / + :py:class:`sfix` """ + if not util.is_constant(size): + raise CompilerError('ORAM size has be a compile-time constant') + if get_program().options.binary: + return BinaryORAM(size, *args, **kwargs) if optimal_threshold is None: if n_threads == 1: threshold = 2**11 @@ -1716,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty): def test_oram(oram_type, N, value_type=sint, iterations=100): stop_grind() oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0) + test_oram_initialized(oram, iterations) + return oram + +def test_oram_initialized(oram, iterations=100): + N = oram.size + value_type = oram.value_type value_type = value_type.get_type(32) index_type = value_type.get_type(log2(N)) start_grind() @@ -1783,7 +1848,7 @@ def test_batch_init(oram_type, N): oram = oram_type(N, value_type) print('initialized') print_reg(cint(0), 'init') - oram.batch_init([value_type(i) for i in range(N)]) + oram.batch_init(Array.create_from(sint(regint.inc(N)))) print_reg(cint(0), 'done') @for_range(N) def f(i): diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index fb1601c3d..b9e3952ba 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -111,24 +111,6 @@ def bucket_size_sorter(x, y): return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z]) -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - - - Returns the network switching config so it may be re-used later. """ - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = permutation.configure_waksman(permutation.random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type(b) for b in c] - permutation.waksman(x, config, reverse=reverse) - permutation.waksman(x, config, reverse=reverse) - - return config - - def LT(a, b): a_bits = bit_decompose(a) b_bits = bit_decompose(b) @@ -472,10 +454,15 @@ def f(): print_ln() # shuffle entries and levels - while len(merged_entries) & (len(merged_entries)-1) != 0: - merged_entries.append(None) #self.root.bucket.empty_entry(False)) - permutation.rec_shuffle(merged_entries, value_type=self.value_type) - merged_entries = [e for e in merged_entries if e is not None] + flat = [] + for x in merged_entries: + flat += list(x[0]) + [x[1]] + flat = self.value_type(flat) + assert len(flat) % len(merged_entries) == 0 + l = len(flat) // len(merged_entries) + shuffled = flat.secure_shuffle(l) + merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]] + for i in range(len(shuffled) // l)] # need to copy entries/levels to memory for re-positioning entries_ram = RAM(self.temp_size, self.entry_type, self.get_array) diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 6e1273ec4..07d3a3e70 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -10,16 +10,6 @@ from Compiler.program import Program _Array = Array -SORT_BITS = [] -insecure_random = Random(0) - -def predefined_comparator(x, y): - """ Assumes SORT_BITS is populated with the required sorting network bits """ - if predefined_comparator.sort_bits_iter is None: - predefined_comparator.sort_bits_iter = iter(SORT_BITS) - return next(predefined_comparator.sort_bits_iter) -predefined_comparator.sort_bits_iter = None - def list_comparator(x, y): """ Uses the first element in the list for comparison """ return x[0] < y[0] @@ -37,10 +27,6 @@ def bitwise_comparator(x, y): def cond_swap_bit(x,y, b): """ swap if b == 1 """ - if x is None: - return y, None - elif y is None: - return x, None if isinstance(x, list): t = [(xi - yi) * b for xi,yi in zip(x, y)] return [xi - ti for xi,ti in zip(x, t)], \ @@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator): else: raise CompilerError('Length of list must be power of two') -def merge(a, b, comp): - """ General length merge (pads to power of 2) """ - while len(a) & (len(a)-1) != 0: - a.append(None) - while len(b) & (len(b)-1) != 0: - b.append(None) - if len(a) < len(b): - a += [None] * (len(b) - len(a)) - elif len(b) < len(a): - b += [None] * (len(b) - len(b)) - t = a + b - odd_even_merge(t, comp) - for i,v in enumerate(t[::]): - if v is None: - t.remove(None) - return t - def sort(a, comp): """ Pads to power of 2, sorts, removes padding """ length = len(a) @@ -112,47 +81,12 @@ def sort(a, comp): odd_even_merge_sort(a, comp) del a[length:] -def recursive_merge(a, comp): - """ Recursively merge a list of sorted lists (initially sorted by size) """ - if len(a) == 1: - return - # merge smallest two lists, place result in correct position, recurse - t = merge(a[0], a[1], comp) - del a[0] - del a[0] - added = False - for i,c in enumerate(a): - if len(c) >= len(t): - a.insert(i, t) - added = True - break - if not added: - a.append(t) - recursive_merge(a, comp) - -def random_perm(n): - """ Generate a random permutation of length n - - WARNING: randomness fixed at compile-time, this is NOT secure - """ - if not Program.prog.options.insecure: - raise CompilerError('no secure implementation of Waksman permution, ' - 'use --insecure to activate') - a = list(range(n)) - for i in range(n-1, 0, -1): - j = insecure_random.randint(0, i) - t = a[i] - a[i] = a[j] - a[j] = t - return a - -def inverse(perm): - inv = [None] * len(perm) - for i, p in enumerate(perm): - inv[p] = i - return inv +# The following functionality for shuffling isn't used any more as it +# has been moved to the virtual machine. The code has been kept for +# reference. -def configure_waksman(perm): +def configure_waksman(perm, n_iter=[0]): + top = n_iter == [0] n = len(perm) if n == 2: return [(perm[0], perm[0])] @@ -175,6 +109,7 @@ def configure_waksman(perm): via = 0 j0 = j while True: + n_iter[0] += 1 #print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2) i = inv_perm[j] @@ -209,8 +144,11 @@ def configure_waksman(perm): assert sorted(p0) == list(range(n//2)) assert sorted(p1) == list(range(n//2)) - p0_config = configure_waksman(p0) - p1_config = configure_waksman(p1) + p0_config = configure_waksman(p0, n_iter) + p1_config = configure_waksman(p1, n_iter) + if top: + print(n_iter[0], 'iterations for Waksman') + assert O[0] == 0, 'not a Waksman network' return [I + O] + [a+b for a,b in zip(p0_config, p1_config)] def waksman(a, config, depth=0, start=0, reverse=False): @@ -358,23 +296,10 @@ def _(i): # nblocks /= 2 # depth -= 1 -def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False): - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = configure_waksman(random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type.bit_type(b) for b in c] - waksman(x, config, reverse=reverse) - waksman(x, config, reverse=reverse) - -def config_shuffle(n, value_type): - """ Compute config for oblivious shuffling. - - Take mod 2 for active sec. """ - perm = random_perm(n) +def config_from_perm(perm, value_type): + n = len(perm) + assert(list(sorted(perm))) == list(range(n)) if n & (n-1) != 0: # pad permutation to power of 2 m = 2**int(math.ceil(math.log(n, 2))) @@ -394,103 +319,3 @@ def _(i): for j,b in enumerate(c): config[i * len(perm) + j] = b return config - -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - WARNING: This is not a properly secure implementation but has roughly the right complexity. - - Returns the network switching config so it may be re-used later. """ - n = len(x) - m = 2**int(math.ceil(math.log(n, 2))) - assert n == m, 'only working for powers of two' - if config is None: - config = config_shuffle(n, value_type) - - if isinstance(x, list): - if isinstance(x[0], list): - length = len(x[0]) - assert len(x) == length - for i in range(length): - xi = Array(m, value_type.reg_type) - for j in range(n): - xi[j] = x[j][i] - for j in range(n, m): - xi[j] = value_type(0) - iter_waksman(xi, config, reverse=reverse) - iter_waksman(xi, config, reverse=reverse) - for j, y in enumerate(xi): - x[j][i] = y - else: - xa = Array(m, value_type.reg_type) - for i in range(n): - xa[i] = x[i] - for i in range(n, m): - xa[i] = value_type(0) - iter_waksman(xa, config, reverse=reverse) - iter_waksman(xa, config, reverse=reverse) - x[:] = xa - elif isinstance(x, Array): - if len(x) != m and config is None: - raise CompilerError('Non-power of 2 Array input not yet supported') - iter_waksman(x, config, reverse=reverse) - iter_waksman(x, config, reverse=reverse) - else: - raise CompilerError('Invalid type for shuffle:', type(x)) - - return config - -def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None): - """ Shuffle a list of ORAM entries. - - Randomly permutes the first "perm_size" entries, leaving the rest (empty - entry padding) in the same position. """ - n = len(x) - l = len(x[0]) - if n & (n-1) != 0: - raise CompilerError('Entries must be padded to power of two length.') - if perm_size is None: - perm_size = n - - xarrays = [Array(n, value_type.reg_type) for i in range(l)] - for i in range(n): - for j,value in enumerate(x[i]): - if isinstance(value, MemValue): - xarrays[j][i] = value.read() - else: - xarrays[j][i] = value - - if config is None: - config = config_shuffle(perm_size, value_type) - for xi in xarrays: - shuffle(xi, config, value_type, reverse) - for i in range(n): - x[i] = entry_cls(xarrays[j][i] for j in range(l)) - return config - - -def sort_zeroes(bits, x, n_ones, value_type): - """ Return Array of values in "x" where the corresponding bit in "bits" is - a 0. - - The total number of zeroes in "bits" must be known. - "bits" and "x" must be Arrays. """ - config = config_shuffle(len(x), value_type) - shuffle(bits, config=config, value_type=value_type) - shuffle(x, config=config, value_type=value_type) - result = Array(n_ones, value_type.reg_type) - - sz = MemValue(0) - last_x = MemValue(value_type(0)) - #for i,b in enumerate(bits): - #if_then(b.reveal() == 0) - #result[sz.read()] = x[i] - #sz += 1 - #end_if() - @for_range(len(bits)) - def f(i): - found = (bits[i].reveal() == 0) - szval = sz.read() - result[szval] = last_x + (x[i] - last_x) * found - sz.write(sz + found) - last_x.write(result[szval]) - return result diff --git a/Compiler/ppc.py b/Compiler/ppc.py index 70c3e851a..de1ab8a3d 100644 --- a/Compiler/ppc.py +++ b/Compiler/ppc.py @@ -15,7 +15,7 @@ print_float_prec(7) # Use to limit the tester workload -MAX_DATA_LENGTH = 500 +MAX_DATA_LENGTH = 500000 MAX_ML_SIZE = 10000 SECOND_LOOP_SIZE = 1000 @@ -333,5 +333,17 @@ def pint_floordiv(self, other): return to_pint(pint_div(self, other)) +def write_matrix_to_file(matrix, matrix_row, matrix_col): + array_result = Array(matrix_row * matrix_col + 2, matrix.value_type) + array_result[0] = matrix.value_type(matrix_row) + array_result[1] = matrix.value_type(matrix_col) + @for_range(matrix_row) + def _(i): + @for_range(matrix_col) + def _(j): + array_result[2 + i*matrix_col+j] = matrix[i][j] + array_result.write_to_file() + + pint.__mod__ = pint_mod pint.__floordiv__ = pint_floordiv diff --git a/Compiler/program.py b/Compiler/program.py index ebc1601c8..9860053fe 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -4,39 +4,44 @@ object that holds various properties of the computation. """ -from Compiler.config import * -from Compiler.exceptions import * -from Compiler.instructions_base import RegType -import Compiler.instructions -import Compiler.instructions_base -import Compiler.instructions_base as inst_base -from . import allocator as al -from . import util -import random -import time -import sys, os, errno import inspect -from collections import defaultdict, deque import itertools import math -from functools import reduce +import os import re +import sys +import hashlib +from collections import defaultdict, deque +from functools import reduce + +import Compiler.instructions +import Compiler.instructions_base +import Compiler.instructions_base as inst_base +from Compiler.config import REG_MAX, USER_MEM, COST +from Compiler.exceptions import CompilerError +from Compiler.instructions_base import RegType +from . import allocator as al +from . import util data_types = dict( - triple = 0, - square = 1, - bit = 2, - inverse = 3, - dabit = 4, + triple=0, + square=1, + bit=2, + inverse=3, + dabit=4, + mixed=5, + random=6, + open=7, ) field_types = dict( - modp = 0, - gf2n = 1, - bit = 2, + modp=0, + gf2n=1, + bit=2, ) + class defaults: debug = False verbose = False @@ -44,13 +49,15 @@ class defaults: ring = 0 field = 0 binary = 0 + garbled = False prime = None galois = 40 - budget = 100000 + budget = 1000 mixed = False edabit = False + invperm = False split = None - cisc = False + cisc = True comparison = None merge_opens = True preserve_mem_order = False @@ -62,8 +69,9 @@ class defaults: insecure = False keep_cisc = False + class Program(object): - """ A program consists of a list of tapes representing the whole + """A program consists of a list of tapes representing the whole computation. When compiling an :file:`.mpc` file, the single instances is @@ -71,20 +79,22 @@ class Program(object): from Python code, an instance has to be created before running any instructions. """ - def __init__(self, args, options=defaults): - from .non_linear import Ring, Prime, KnownPrime + + def __init__(self, args, options=defaults, name=None): + from .non_linear import KnownPrime, Prime + self.options = options self.verbose = options.verbose self.args = args + self.name = name self.init_names(args) self._security = 40 self.prime = None self.tapes = [] - if sum(x != 0 for x in(options.ring, options.field, - options.binary)) > 1: - raise CompilerError('can only use one out of -B, -R, -F') + if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1: + raise CompilerError("can only use one out of -B, -R, -F") if options.prime and (options.ring or options.binary): - raise CompilerError('can only use one out of -B, -R, -p') + raise CompilerError("can only use one out of -B, -R, -p") if options.ring: self.set_ring_size(int(options.ring)) else: @@ -93,19 +103,20 @@ def __init__(self, args, options=defaults): self.prime = int(options.prime) max_bit_length = int(options.prime).bit_length() - 2 if self.bit_length > max_bit_length: - raise CompilerError('integer bit length can be maximal %s' % - max_bit_length) + raise CompilerError( + "integer bit length can be maximal %s" % max_bit_length + ) self.bit_length = self.bit_length or max_bit_length self.non_linear = KnownPrime(self.prime) else: self.non_linear = Prime(self.security) if not self.bit_length: self.bit_length = 64 - print('Default bit length:', self.bit_length) - print('Default security parameter:', self.security) + print("Default bit length:", self.bit_length) + print("Default security parameter:", self.security) self.galois_length = int(options.galois) if self.verbose: - print('Galois length:', self.galois_length) + print("Galois length:", self.galois_length) self.tape_counter = 0 self._curr_tape = None self.DEBUG = options.debug @@ -118,30 +129,51 @@ def __init__(self, args, options=defaults): self.n_threads = 1 self.public_input_file = None self.types = {} - self.budget = int(self.options.budget) - self.to_merge = [Compiler.instructions.asm_open_class, \ - Compiler.instructions.gasm_open_class, \ - Compiler.instructions.muls_class, \ - Compiler.instructions.gmuls_class, \ - Compiler.instructions.mulrs_class, \ - Compiler.instructions.gmulrs, \ - Compiler.instructions.dotprods_class, \ - Compiler.instructions.gdotprods_class, \ - Compiler.instructions.asm_input_class, \ - Compiler.instructions.gasm_input_class, - Compiler.instructions.inputfix_class, - Compiler.instructions.inputfloat_class, - Compiler.instructions.inputmixed_class, - Compiler.instructions.trunc_pr_class, - Compiler.instructions_base.Mergeable] + if self.options.budget: + self.budget = int(self.options.budget) + else: + if self.options.optimize_hard: + self.budget = 100000 + else: + self.budget = defaults.budget + self.to_merge = [ + Compiler.instructions.asm_open_class, + Compiler.instructions.gasm_open_class, + Compiler.instructions.muls_class, + Compiler.instructions.gmuls_class, + Compiler.instructions.mulrs_class, + Compiler.instructions.gmulrs, + Compiler.instructions.dotprods_class, + Compiler.instructions.gdotprods_class, + Compiler.instructions.asm_input_class, + Compiler.instructions.gasm_input_class, + Compiler.instructions.inputfix_class, + Compiler.instructions.inputfloat_class, + Compiler.instructions.inputmixed_class, + Compiler.instructions.trunc_pr_class, + Compiler.instructions_base.Mergeable, + ] import Compiler.GC.instructions as gc - self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ - gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] + + self.to_merge += [ + gc.ldmsdi, + gc.stmsdi, + gc.ldmsd, + gc.stmsd, + gc.stmsdci, + gc.andrs, + gc.ands, + gc.inputb, + gc.inputbvec, + gc.reveal, + ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ self.use_dabit = options.mixed """ Setting whether to use daBits for non-linear functionality. """ self._edabit = options.edabit + """ Whether to use the low-level INVPERM instruction (only implemented with the assumption of a semi-honest two-party environment)""" + self._invperm = options.invperm self._split = False if options.split: self.use_split(int(options.split)) @@ -152,8 +184,14 @@ def __init__(self, args, options=defaults): self.relevant_opts = set() self.n_running_threads = None self.input_files = {} + self.base_addresses = {} + self._protect_memory = False + if not self.options.cisc: + self.options.cisc = not self.options.optimize_hard + Program.prog = self - from . import instructions_base, instructions, types, comparison + from . import comparison, instructions, instructions_base, types + instructions.program = self instructions_base.program = self types.program = self @@ -164,53 +202,53 @@ def get_args(self): return self.args def max_par_tapes(self): - """ Upper bound on number of tapes that will be run in parallel. - (Excludes empty tapes) """ + """Upper bound on number of tapes that will be run in parallel. + (Excludes empty tapes)""" return self.n_threads - + def init_names(self, args): # ignore path to file - source must be in Programs/Source - if 'Programs' in os.listdir(os.getcwd()): + if "Programs" in os.listdir(os.getcwd()): # compile prog in ./Programs/Source directory - self.programs_dir = os.getcwd() + '/Programs' + self.programs_dir = "Programs" else: # assume source is in main SPDZ directory - self.programs_dir = sys.path[0] + '/Programs' + self.programs_dir = sys.path[0] + "/Programs" if self.verbose: - print('Compiling program in', self.programs_dir) - + print("Compiling program in", self.programs_dir) + # create extra directories if needed - for dirname in ['Public-Input', 'Bytecode', 'Schedules']: - if not os.path.exists(self.programs_dir + '/' + dirname): - os.mkdir(self.programs_dir + '/' + dirname) - - progname = args[0].split('/')[-1] - if progname.endswith('.mpc'): - progname = progname[:-4] - - if os.path.exists(args[0]): - self.infile = args[0] - else: - self.infile = self.programs_dir + '/Source/' + progname + '.mpc' + for dirname in ["Public-Input", "Bytecode", "Schedules"]: + if not os.path.exists(self.programs_dir + "/" + dirname): + os.mkdir(self.programs_dir + "/" + dirname) + + if self.name is None: + self.name = args[0].split("/")[-1] + if self.name.endswith(".mpc"): + self.name = self.name[:-4] + + if os.path.exists(args[0]): + self.infile = args[0] + else: + self.infile = self.programs_dir + "/Source/" + self.name + ".mpc" """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames """ if self.options.outfile: - self.name = self.options.outfile + '-' + progname + self.name = self.options.outfile + "-" + self.name else: - self.name = progname + self.name = self.name if len(args) > 1: - self.name += '-' + '-'.join(re.sub('/', '_', arg) - for arg in args[1:]) - self.progname = progname + self.name += "-" + "-".join(re.sub("/", "_", arg) for arg in args[1:]) def set_ring_size(self, ring_size): from .non_linear import Ring + for tape in self.tapes: - prev = tape.req_bit_length['p'] + prev = tape.req_bit_length["p"] if prev and prev != ring_size: - raise CompilerError('cannot have different ring sizes') + raise CompilerError("cannot have different ring sizes") self.bit_length = ring_size - 1 self.non_linear = Ring(ring_size) self.options.ring = str(ring_size) @@ -234,7 +272,8 @@ def g(): :param function: Python function defining the thread :param args: arguments to the function :param name: name used for files - :param single_thread: Boolean indicating whether tape will never be run in parallel to itself + :param single_thread: Boolean indicating whether tape will + never be run in parallel to itself :returns: tape handle """ @@ -258,20 +297,22 @@ def run_tape(self, tape_index, arg): return self.run_tapes([[tape_index, arg]])[0] def run_tapes(self, args): - """ Run tapes in parallel. See :py:func:`new_tape` for an example. + """Run tapes in parallel. See :py:func:`new_tape` for an example. - :param args: list of tape handles or tuples of tape handle and extra argument (for :py:func:`~Compiler.library.get_arg`) + :param args: list of tape handles or tuples of tape handle and extra + argument (for :py:func:`~Compiler.library.get_arg`) :returns: list of thread numbers """ if not self.curr_tape.singular: - raise CompilerError('Compiler does not support ' \ - 'recursive spawning of threads') + raise CompilerError( + "Compiler does not support " "recursive spawning of threads" + ) args = [list(util.tuplify(arg)) for arg in args] singular_tapes = set() for arg in args: if self.tapes[arg[0]].singular: if arg[0] in singular_tapes: - raise CompilerError('cannot run singular tape in parallel') + raise CompilerError("cannot run singular tape in parallel") singular_tapes.add(arg[0]) assert len(arg) assert len(arg) <= 2 @@ -286,73 +327,78 @@ def run_tapes(self, args): else: thread_numbers.append(self.n_threads) self.n_threads += 1 - self.curr_tape.start_new_basicblock(name='pre-run_tape') - Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in - zip(thread_numbers, args)), [])) - self.curr_tape.start_new_basicblock(name='post-run_tape') + self.curr_tape.start_new_basicblock(name="pre-run_tape") + Compiler.instructions.run_tape( + *sum(([x] + list(y) for x, y in zip(thread_numbers, args)), []) + ) + self.curr_tape.start_new_basicblock(name="post-run_tape") for arg in args: - self.curr_tape.req_node.children.append( - self.tapes[arg[0]].req_tree) + self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree) return thread_numbers def join_tape(self, thread_number): self.join_tapes([thread_number]) def join_tapes(self, thread_numbers): - """ Wait for completion of tapes. See :py:func:`new_tape` for an example. + """Wait for completion of tapes. See :py:func:`new_tape` for an example. :param thread_numbers: list of thread numbers """ - self.curr_tape.start_new_basicblock(name='pre-join_tape') + self.curr_tape.start_new_basicblock(name="pre-join_tape") for thread_number in thread_numbers: Compiler.instructions.join_tape(thread_number) self.curr_tape.free_threads.add(thread_number) - self.curr_tape.start_new_basicblock(name='post-join_tape') + self.curr_tape.start_new_basicblock(name="post-join_tape") def update_req(self, tape): if self.req_num is None: self.req_num = tape.req_num else: self.req_num += tape.req_num - + def write_bytes(self): - """ Write all non-empty threads and schedule to files. """ + """Write all non-empty threads and schedule to files.""" nonempty_tapes = [t for t in self.tapes] - sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name - sch_file = open(sch_filename, 'w') - print('Writing to', sch_filename) - sch_file.write(str(self.max_par_tapes()) + '\n') - sch_file.write(str(len(nonempty_tapes)) + '\n') - sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') - sch_file.write('1 0\n') - sch_file.write('0\n') - sch_file.write(' '.join(sys.argv) + '\n') - req = max(x.req_bit_length['p'] for x in self.tapes) + sch_filename = self.programs_dir + "/Schedules/%s.sch" % self.name + sch_file = open(sch_filename, "w") + print("Writing to", sch_filename) + sch_file.write(str(self.max_par_tapes()) + "\n") + sch_file.write(str(len(nonempty_tapes)) + "\n") + sch_file.write(" ".join("%s:%d" % (tape.name, len(tape)) + for tape in nonempty_tapes) + "\n") + sch_file.write("1 0\n") + sch_file.write("0\n") + sch_file.write(" ".join(sys.argv) + "\n") + req = max(x.req_bit_length["p"] for x in self.tapes) if self.options.ring: - sch_file.write('R:%s' % self.options.ring) + sch_file.write("R:%s" % self.options.ring) elif self.options.prime: - sch_file.write('p:%s' % self.options.prime) + sch_file.write("p:%s" % self.options.prime) else: - sch_file.write('lgp:%s' % req) - sch_file.write('\n') - sch_file.write('opts: %s\n' % ' '.join(self.relevant_opts)) + sch_file.write("lgp:%s" % req) + sch_file.write("\n") + sch_file.write("opts: %s\n" % " ".join(self.relevant_opts)) + sch_file.close() + h = hashlib.sha256() for tape in self.tapes: tape.write_bytes() + h.update(tape.hash) + print('Hash:', h.hexdigest()) def finalize_tape(self, tape): if not tape.purged: tape.optimize(self.options) tape.write_bytes() if self.options.asmoutfile: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) tape.purge() - + @property def curr_tape(self): - """ The tape that is currently running.""" + """The tape that is currently running.""" if self._curr_tape is None: assert not self.tapes self._curr_tape = Tape(self.name, self) @@ -365,13 +411,13 @@ def curr_tape(self, value): @property def curr_block(self): - """ The basic block that is currently being created. """ + """The basic block that is currently being created.""" return self.curr_tape.active_basicblock - + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): - """ Allocate memory from the top """ + """Allocate memory from the top""" if not isinstance(size, int): - raise CompilerError('size must be known at compile time') + raise CompilerError("size must be known at compile time") if size == 0: return if isinstance(mem_type, type): @@ -389,8 +435,7 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): single_size = size size *= self.n_running_threads else: - raise CompilerError('cannot allocate memory ' - 'outside main thread') + raise CompilerError("cannot allocate memory " "outside main thread") blocks = self.free_mem_blocks[mem_type] addr = blocks.pop(size) if addr is not None: @@ -400,21 +445,27 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) - self.allocated_mem_blocks[addr,mem_type] = size + if addr + size >= 2**64: + raise CompilerError("allocation exceeded for type '%s'" % mem_type) + self.allocated_mem_blocks[addr, mem_type] = size if single_size: from .library import get_thread_number, runtime_error_if + tn = get_thread_number() - runtime_error_if(tn > self.n_running_threads, 'malloc') - return addr + single_size * (tn - 1) + runtime_error_if(tn > self.n_running_threads, "malloc") + res = addr + single_size * (tn - 1) + self.base_addresses[str(res)] = addr + return res else: return addr def free(self, addr, mem_type): - """ Free memory """ - if self.curr_block.alloc_pool \ - is not self.curr_tape.basicblocks[0].alloc_pool: - raise CompilerError('Cannot free memory within function block') - size = self.allocated_mem_blocks.pop((addr,mem_type)) + """Free memory""" + if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool: + raise CompilerError("Cannot free memory within function block") + if not util.is_constant(addr): + addr = self.base_addresses[str(addr)] + size = self.allocated_mem_blocks.pop((addr, mem_type)) self.free_mem_blocks[mem_type].push(addr, size) def finalize(self): @@ -432,47 +483,60 @@ def finalize(self): if self.options.asmoutfile: for tape in self.tapes: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) def finalize_memory(self): - from . import library - self.curr_tape.start_new_basicblock(None, 'memory-usage') + self.curr_tape.start_new_basicblock(None, "memory-usage") # reset register counter to 0 if not self.options.noreallocate: self.curr_tape.init_registers() - for mem_type,size in sorted(self.allocated_mem.items()): - if size: - #print "Memory of type '%s' of size %d" % (mem_type, size) + for mem_type, size in sorted(self.allocated_mem.items()): + if size and (not self.options.garbled or \ + mem_type not in ('s', 'sg', 'c', 'cg')): + # print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: self.types[mem_type].load_mem(size - 1, mem_type) else: from Compiler.types import _get_type + _get_type(mem_type).load_mem(size - 1, mem_type) if self.verbose: if self.saved: - print('Saved %s memory units through reallocation' % self.saved) + print("Saved %s memory units through reallocation" % self.saved) def public_input(self, x): - """ Append a value to the public input file. """ + """Append a value to the public input file.""" if self.public_input_file is None: - self.public_input_file = open(self.programs_dir + - '/Public-Input/%s' % self.name, 'w') - self.public_input_file.write('%s\n' % str(x)) + self.public_input_file = open( + self.programs_dir + "/Public-Input/%s" % self.name, "w" + ) + self.public_input_file.write("%s\n" % str(x)) + + def get_binary_input_file(self, player): + key = player, 'bin' + if key not in self.input_files: + filename = 'Player-Data/Input-Binary-P%d-0' % player + print('Writing binary data to', filename) + self.input_files[key] = open(filename, 'wb') + return self.input_files[key] def set_bit_length(self, bit_length): - """ Change the integer bit length for non-linear functions. """ + """Change the integer bit length for non-linear functions.""" self.bit_length = bit_length - print('Changed bit length for comparisons etc. to', bit_length) + print("Changed bit length for comparisons etc. to", bit_length) def set_security(self, security): + changed = self._security != security self._security = security self.non_linear.set_security(security) - print('Changed statistical security for comparison etc. to', security) + if changed: + print("Changed statistical security for comparison etc. to", + security) @property def security(self): - """ The statistical security parameter for non-linear - functions. """ + """The statistical security parameter for non-linear + functions.""" return self._security @security.setter @@ -480,7 +544,8 @@ def security(self, security): self.set_security(security) def optimize_for_gc(self): - pass + import Compiler.GC.instructions as gc + self.to_merge += [gc.xors] def get_tape_counter(self): res = self.tape_counter @@ -490,7 +555,7 @@ def get_tape_counter(self): @property def use_trunc_pr(self): if not self._use_trunc_pr: - self.relevant_opts.add('trunc_pr') + self.relevant_opts.add("trunc_pr") return self._use_trunc_pr @use_trunc_pr.setter @@ -498,7 +563,7 @@ def use_trunc_pr(self, change): self._use_trunc_pr = change def use_edabit(self, change=None): - """ Setting whether to use edaBits for non-linear + """Setting whether to use edaBits for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -506,16 +571,30 @@ def use_edabit(self, change=None): """ if change is None: if not self._edabit: - self.relevant_opts.add('edabit') + self.relevant_opts.add("edabit") return self._edabit else: self._edabit = change + def use_invperm(self, change=None): + """ Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used. + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ + if change is None: + if not self._invperm: + self.relevant_opts.add("invperm") + return self._invperm + else: + self._invperm = change + + def use_edabit_for(self, *args): return True def use_split(self, change=None): - """ Setting whether to use local arithmetic-binary share + """Setting whether to use local arithmetic-binary share conversion for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -523,16 +602,16 @@ def use_split(self, change=None): """ if change is None: if not self._split: - self.relevant_opts.add('split') + self.relevant_opts.add("split") return self._split else: if change and not self.options.ring: - raise CompilerError('splitting only supported for rings') - assert change > 1 or change == False + raise CompilerError("splitting only supported for rings") + assert change > 1 or change is False self._split = change def use_square(self, change=None): - """ Setting whether to use preprocessed square tuples + """Setting whether to use preprocessed square tuples (default: false). :param change: change setting if not :py:obj:`None` @@ -556,34 +635,71 @@ def linear_rounds(self, change=None): self._linear_rounds = change def options_from_args(self): - """ Set a number of options from the command-line arguments. """ - if 'trunc_pr' in self.args: + """Set a number of options from the command-line arguments.""" + if "trunc_pr" in self.args: self.use_trunc_pr = True - if 'signed_trunc_pr' in self.args: + if "signed_trunc_pr" in self.args: self.use_trunc_pr = -1 - if 'split' in self.args or 'split3' in self.args: + if "split" in self.args or "split3" in self.args: self.use_split(3) for arg in self.args: - m = re.match('split([0-9]+)', arg) + m = re.match("split([0-9]+)", arg) if m: self.use_split(int(m.group(1))) - if 'raw' in self.args: + if "raw" in self.args: self.always_raw(True) - if 'edabit' in self.args: + if "edabit" in self.args: self.use_edabit(True) - if 'linear_rounds' in self.args: + if "invperm" in self.args: + self.use_invperm(True) + if "linear_rounds" in self.args: self.linear_rounds(True) def disable_memory_warnings(self): self.warn_about_mem.append(False) self.curr_block.warn_about_mem = False + def protect_memory(self, status): + """ Enable or disable memory protection. """ + self._protect_memory = status + + def use_cisc(self): + return self.options.cisc and (not self.prime or self.rabbit_gap()) + + def rabbit_gap(self): + assert self.prime + p = self.prime + logp = int(round(math.log(p, 2))) + return abs(p - 2 ** logp) / p < 2 ** -self.security + + @staticmethod + def read_tapes(schedule): + m = re.search(r"([^/]*)\.mpc", schedule) + if m: + schedule = m.group(1) + if not os.path.exists(schedule): + schedule = "Programs/Schedules/%s.sch" % schedule + + try: + lines = open(schedule).readlines() + except FileNotFoundError: + print( + "%s not found, have you compiled the program?" % schedule, + file=sys.stderr, + ) + sys.exit(1) + + for tapename in lines[2].split(" "): + yield tapename.strip().split(":")[0] + + class Tape: - """ A tape contains a list of basic blocks, onto which instructions are added. """ + """A tape contains a list of basic blocks, onto which instructions are added.""" + def __init__(self, name, program): - """ Set prime p and the initial instructions and registers. """ + """Set prime p and the initial instructions and registers.""" self.program = program - name += '-%d' % program.get_tape_counter() + name += "-%d" % program.get_tape_counter() self.init_names(name) self.init_registers() self.req_tree = self.ReqNode(name) @@ -602,6 +718,7 @@ def __init__(self, name, program): self.singular = True self.free_threads = set() self.loop_breaks = [] + self.warned_about_mem = False class BasicBlock(object): def __init__(self, parent, name, scope, exit_condition=None): @@ -622,6 +739,7 @@ def __init__(self, parent, name, scope, exit_condition=None): self.purged = False self.n_rounds = 0 self.n_to_merge = 0 + self.rounds = Tape.ReqNum() self.warn_about_mem = parent.program.warn_about_mem[-1] def __len__(self): @@ -637,9 +755,9 @@ def set_return(self, previous_block, sub_block): def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset - + def set_exit(self, condition, exit_true=None): - """ Sets the block which we start from next, depending on the condition. + """Sets the block which we start from next, depending on the condition. (Default is to go to next block in the list) """ @@ -647,34 +765,33 @@ def set_exit(self, condition, exit_true=None): self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False - + def add_jump(self): - """ Add the jump for this block's exit condition to list of - instructions (must be done after merging) """ + """Add the jump for this block's exit condition to list of + instructions (must be done after merging)""" self.instructions.append(self.exit_condition) - + def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) - + def adjust_jump(self): - """ Set the correct relative jump offset """ + """Set the correct relative jump offset""" offset = self.get_offset(self.exit_block) self.exit_condition.set_relative_jump(offset) - #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self, retain_usage=True): def relevant(inst): - req_node = Tape.ReqNode('') + req_node = Tape.ReqNode("") req_node.num = Tape.ReqNum() inst.add_usage(req_node) return req_node.num != {} + if retain_usage: - self.usage_instructions = list(filter(relevant, - self.instructions)) + self.usage_instructions = list(filter(relevant, self.instructions)) else: self.usage_instructions = [] if len(self.usage_instructions) > 1000: - print('Retaining %d instructions' % len(self.usage_instructions)) + print("Retaining %d instructions" % len(self.usage_instructions)) del self.instructions self.purged = True @@ -685,13 +802,15 @@ def add_usage(self, req_node): instructions = self.instructions for inst in instructions: inst.add_usage(req_node) - req_node.num['all', 'round'] += self.n_rounds - req_node.num['all', 'inv'] += self.n_to_merge + req_node.num["all", "round"] += self.n_rounds + req_node.num["all", "inv"] += self.n_to_merge + req_node.num += self.rounds def expand_cisc(self): new_instructions = [] - if self.parent.program.options.keep_cisc: - skip = ['LTZ', 'Trunc'] + if self.parent.program.options.keep_cisc is not None: + skip = ["LTZ", "Trunc"] + skip += self.parent.program.options.keep_cisc.split(",") else: skip = [] for inst in self.instructions: @@ -704,38 +823,45 @@ def __str__(self): return self.name def is_empty(self): - """ Returns True if the list of basic blocks is empty. + """Returns True if the list of basic blocks is empty. Note: False is returned even when tape only contains basic blocks with no instructions. However, these are removed when - optimize is called. """ + optimize is called.""" if not self.purged: - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 return self._is_empty - def start_new_basicblock(self, scope=False, name=''): + def start_new_basicblock(self, scope=False, name=""): # use False because None means no scope if scope is False: scope = self.active_basicblock - suffix = '%s-%d' % (name, self.block_counter) + suffix = "%s-%d" % (name, self.block_counter) self.block_counter += 1 - sub = self.BasicBlock(self, self.name + '-' + suffix, scope) + sub = self.BasicBlock(self, self.name + "-" + suffix, scope) self.basicblocks.append(sub) self.active_basicblock = sub self.req_node.add_block(sub) - #print 'Compiling basic block', sub.name + # print 'Compiling basic block', sub.name def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) - + def init_names(self, name): self.name = name - self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' + self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" + + def __len__(self): + if self.purged: + return self.size + else: + return sum(len(block) for block in self.basicblocks) def purge(self): + self.size = len(self) for block in self.basicblocks: block.purge() - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 del self.basicblocks del self.active_basicblock self.purged = True @@ -745,19 +871,29 @@ def wrapper(self, *args, **kwargs): if self.purged: return return function(self, *args, **kwargs) + return wrapper @unpurged def optimize(self, options): if len(self.basicblocks) == 0: - print('Tape %s is empty' % self.name) + print("Tape %s is empty" % self.name) return if self.if_states: - raise CompilerError('Unclosed if/else blocks') + print("Tracebacks for open blocks:") + for state in self.if_states: + try: + print(util.format_trace(state.caller)) + except AttributeError: + pass + print() + raise CompilerError("Unclosed if/else blocks, see tracebacks above") if self.program.verbose: - print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) + print( + "Processing tape", self.name, "with %d blocks" % len(self.basicblocks) + ) for block in self.basicblocks: al.determine_scope(block, options) @@ -765,41 +901,58 @@ def optimize(self, options): # merge open instructions # need to do this if there are several blocks if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: - for i,block in enumerate(self.basicblocks): + for i, block in enumerate(self.basicblocks): if len(block.instructions) > 0 and self.program.verbose: - print('Processing basic block %s, %d/%d, %d instructions' % \ - (block.name, i, len(self.basicblocks), \ - len(block.instructions))) + print( + "Processing basic block %s, %d/%d, %d instructions" + % ( + block.name, + i, + len(self.basicblocks), + len(block.instructions), + ) + ) # the next call is necessary for allocation later even without merging - merger = al.Merger(block, options, \ - tuple(self.program.to_merge)) + merger = al.Merger(block, options, tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 1000000: - print('Eliminate dead code...') + print("Eliminate dead code...") merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: block.used_from_scope = util.set_by_id() continue if len(block.instructions) > 1000000: - print('Merging instructions...') + print("Merging instructions...") numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) + if options.verbose: + block.rounds = merger.req_num if merger.counter and self.program.verbose: - print('Block requires', \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.counter.items()))) + print( + "Block requires", + ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.counter.items()) + ), + ) if merger.counter and self.program.verbose: - print('Block requires %s rounds' % \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.rounds.items()))) + print( + "Block requires %s rounds" + % ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.rounds.items()) + ) + ) # free memory merger = None if options.dead_code_elimination: - block.instructions = [x for x in block.instructions if x is not None] + block.instructions = [ + x for x in block.instructions if x is not None + ] if not (options.merge_opens and self.merge_opens): - print('Not merging instructions in tape %s' % self.name) + print("Not merging instructions in tape %s" % self.name) if options.cisc: self.expand_cisc() @@ -824,19 +977,27 @@ def optimize(self, options): reg_counts = self.count_regs() if options.noreallocate: if self.program.verbose: - print('Tape register usage:', dict(reg_counts)) + print("Tape register usage:", dict(reg_counts)) else: if self.program.verbose: - print('Tape register usage before re-allocation:', - dict(reg_counts)) - print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) - print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) - print('Re-allocating...') + print("Tape register usage before re-allocation:", dict(reg_counts)) + print( + "modp: %d clear, %d secret" + % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) + ) + print( + "GF2N: %d clear, %d secret" + % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) + ) + print("Re-allocating...") allocator = al.StraightlineAllocator(REG_MAX, self.program) + def alloc(block): - for reg in sorted(block.used_from_scope, - key=lambda x: (x.reg_type, x.i)): + for reg in sorted( + block.used_from_scope, key=lambda x: (x.reg_type, x.i) + ): allocator.alloc_reg(reg, block.alloc_pool) + def alloc_loop(block): left = deque([block]) while left: @@ -844,73 +1005,84 @@ def alloc_loop(block): alloc(block) for child in block.children: left.append(child) - for i,block in enumerate(reversed(self.basicblocks)): + + for i, block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 1000000: - print('Allocating %s, %d/%d' % \ - (block.name, i, len(self.basicblocks))) + print( + "Allocating %s, %d/%d" % (block.name, i, len(self.basicblocks)) + ) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() - if isinstance(jump, int) and jump < 0 and \ - block.exit_block.scope is not None: + if ( + isinstance(jump, int) + and jump < 0 + and block.exit_block.scope is not None + ): alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) allocator.finalize(options) if self.program.verbose: - print('Tape register usage:', dict(allocator.usage)) + print("Tape register usage:", dict(allocator.usage)) # offline data requirements if self.program.verbose: - print('Compile offline data requirements...') + print("Compile offline data requirements...") self.req_num = self.req_tree.aggregate() if self.program.verbose: - print('Tape requires', self.req_num) - for req,num in sorted(self.req_num.items()): - if num == float('inf') or num >= 2 ** 32: + print("Tape requires", self.req_num) + for req, num in sorted(self.req_num.items()): + if num == float("inf") or num >= 2**64: num = -1 if req[1] in data_types: self.basicblocks[-1].instructions.append( - Compiler.instructions.use(field_types[req[0]], \ - data_types[req[1]], num, \ - add_to_prog=False)) - elif req[1] == 'input': + Compiler.instructions.use( + field_types[req[0]], data_types[req[1]], num, add_to_prog=False + ) + ) + elif req[1] == "input": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_inp(field_types[req[0]], \ - req[2], num, \ - add_to_prog=False)) - elif req[0] == 'modp': + Compiler.instructions.use_inp( + field_types[req[0]], req[2], num, add_to_prog=False + ) + ) + elif req[0] == "modp": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'gf2n': + Compiler.instructions.use_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "gf2n": self.basicblocks[-1].instructions.append( - Compiler.instructions.guse_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'edabit': + Compiler.instructions.guse_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "edabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(False, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'sedabit': + Compiler.instructions.use_edabit( + False, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "sedabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(True, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'matmul': + Compiler.instructions.use_edabit( + True, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "matmul": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_matmul(*req[1], num, \ - add_to_prog=False)) + Compiler.instructions.use_matmul(*req[1], num, add_to_prog=False) + ) if not self.is_empty(): # bit length requirement - for x in ('p', '2'): + for x in ("p", "2"): if self.req_bit_length[x]: bl = self.req_bit_length[x] if self.program.options.ring: bl = -int(self.program.options.ring) self.basicblocks[-1].instructions.append( - Compiler.instructions.reqbl(bl, - add_to_prog=False)) + Compiler.instructions.reqbl(bl, add_to_prog=False) + ) if self.program.verbose: - print('Tape requires prime bit length', self.req_bit_length['p']) - print('Tape requires galois bit length', self.req_bit_length['2']) + print("Tape requires prime bit length", self.req_bit_length["p"]) + print("Tape requires galois bit length", self.req_bit_length["2"]) @unpurged def expand_cisc(self): @@ -919,93 +1091,104 @@ def expand_cisc(self): @unpurged def _get_instructions(self): - return itertools.chain.\ - from_iterable(b.instructions for b in self.basicblocks) + return itertools.chain.from_iterable(b.instructions for b in self.basicblocks) @unpurged def get_encoding(self): - """ Get the encoding of the program, in human-readable format. """ + """Get the encoding of the program, in human-readable format.""" return [i.get_encoding() for i in self._get_instructions() if i is not None] - + @unpurged def get_bytes(self): - """ Get the byte encoding of the program as an actual string of bytes. """ - return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) - + """Get the byte encoding of the program as an actual string of bytes.""" + return b"".join( + i.get_bytes() for i in self._get_instructions() if i is not None + ) + @unpurged def write_encoding(self, filename): - """ Write the readable encoding to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the readable encoding to a file.""" + print("Writing to", filename) + f = open(filename, "w") for line in self.get_encoding(): - f.write(str(line) + '\n') + f.write(str(line) + "\n") f.close() - + @unpurged def write_str(self, filename): - """ Write the sequence of instructions to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the sequence of instructions to a file.""" + print("Writing to", filename) + f = open(filename, "w") n = 0 for block in self.basicblocks: if block.instructions: - f.write('# %s\n' % block.name) + f.write("# %s\n" % block.name) for line in block.instructions: - f.write('%s # %d\n' % (line, n)) + f.write("%s # %d\n" % (line, n)) n += 1 f.close() - + @unpurged def write_bytes(self, filename=None): - """ Write the program's byte encoding to a file. """ + """Write the program's byte encoding to a file.""" if filename is None: filename = self.outfile - if not filename.endswith('.bc'): - filename += '.bc' - if not 'Bytecode' in filename: - filename = self.program.programs_dir + '/Bytecode/' + filename - print('Writing to', filename) - f = open(filename, 'wb') + if not filename.endswith(".bc"): + filename += ".bc" + if "Bytecode" not in filename: + filename = self.program.programs_dir + "/Bytecode/" + filename + print("Writing to", filename) + f = open(filename, "wb") + h = hashlib.sha256() for i in self._get_instructions(): if i is not None: - f.write(i.get_bytes()) + b = i.get_bytes() + f.write(b) + h.update(b) f.close() - + self.hash = h.digest() + def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) - + def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] - + def __str__(self): return self.name class ReqNum(defaultdict): def __init__(self, init={}): super(Tape.ReqNum, self).__init__(lambda: 0, init) + def __add__(self, other): res = Tape.ReqNum() - for i,count in list(self.items()): - res[i] += count - for i,count in list(other.items()): + for i, count in list(self.items()): + res[i] += count + for i, count in list(other.items()): res[i] += count return res + def __mul__(self, other): res = Tape.ReqNum() for i in self: res[i] = other * self[i] return res + __rmul__ = __mul__ + def set_all(self, value): - if value == float('inf') and self['all', 'inv'] > 0: - print('Going to unknown from %s' % self) + if Program.prog.options.verbose and \ + value == float("inf") and self["all", "inv"] > 0: + print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: res[i] = value return res + def max(self, other): res = Tape.ReqNum() for i in self: @@ -1013,82 +1196,105 @@ def max(self, other): for i in other: res[i] = max(self[i], other[i]) return res + def cost(self): - return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ - if req[1] != 'input' and req[0] != 'edabit') + return sum( + num * COST[req[0]][req[1]] + for req, num in list(self.items()) + if req[1] != "input" and req[0] != "edabit" + ) + def pretty(self): - t = lambda x: 'integer' if x == 'modp' else x + def t(x): + return "integer" if x == "modp" else x + res = [] for req, num in self.items(): domain = t(req[0]) - n = '%12.0f' % num - if req[1] == 'input': - res += ['%s %s inputs from player %d' \ - % (n, domain, req[2])] - elif domain.endswith('edabit'): - if domain == 'sedabit': - eda = 'strict edabits' + if num < 0: + num = float('inf') + n = "%12.0f" % num + if req[1] == "input": + res += ["%s %s inputs from player %d" % (n, domain, req[2])] + elif domain.endswith("edabit"): + if domain == "sedabit": + eda = "strict edabits" else: - eda = 'loose edabits' - res += ['%s %s of length %d' % (n, eda, req[1])] - elif domain == 'matmul': - res += ['%s matrix multiplications (%dx%d * %dx%d)' % - (n, req[1][0], req[1][1], req[1][1], req[1][2])] - elif req[0] != 'all': - res += ['%s %s %ss' % (n, domain, req[1])] - if self['all','round']: - res += ['% 12.0f virtual machine rounds' % self['all','round']] + eda = "loose edabits" + res += ["%s %s of length %d" % (n, eda, req[1])] + elif domain == "matmul": + res += [ + "%s matrix multiplications (%dx%d * %dx%d)" + % (n, req[1][0], req[1][1], req[1][1], req[1][2]) + ] + elif req[0] != "all": + res += ["%s %s %ss" % (n, domain, req[1])] + if self["all", "round"]: + res += ["% 12.0f virtual machine rounds" % self["all", "round"]] return res + def __str__(self): - return ', '.join(self.pretty()) + return ", ".join(self.pretty()) + def __repr__(self): return repr(dict(self)) class ReqNode(object): - __slots__ = ['num', 'children', 'name', 'blocks'] + __slots__ = ["num", "children", "name", "blocks"] + def __init__(self, name): self.children = [] self.name = name self.blocks = [] + def aggregate(self, *args): self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) - res = reduce(lambda x,y: x + y.aggregate(self.name), - self.children, self.num) + res = reduce( + lambda x, y: x + y.aggregate(self.name), self.children, self.num + ) return res + def increment(self, data_type, num=1): self.num[data_type] += num + def add_block(self, block): self.blocks.append(block) class ReqChild(object): - __slots__ = ['aggregator', 'nodes', 'parent'] + __slots__ = ["aggregator", "nodes", "parent"] + def __init__(self, aggregator, parent): self.aggregator = aggregator self.nodes = [] self.parent = parent + def aggregate(self, name): res = self.aggregator([node.aggregate() for node in self.nodes]) try: n_reps = self.aggregator([1]) - n_rounds = res['all', 'round'] - n_invs = res['all', 'inv'] + n_rounds = res["all", "round"] + n_invs = res["all", "inv"] if (n_invs / n_rounds) * 1000 < n_reps: - print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ - '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)) - except: + print( + self.nodes[0].blocks[0].name, + "blowing up rounds: ", + "(%d / %d) ** 3 < %d" % (n_rounds, n_reps, n_invs), + ) + except Exception: pass return res + def add_node(self, tape, name): new_node = Tape.ReqNode(name) self.nodes.append(new_node) tape.req_node = new_node - def open_scope(self, aggregator, scope=False, name=''): + def open_scope(self, aggregator, scope=False, name=""): child = self.ReqChild(aggregator, self.req_node) self.req_node.children.append(child) - child.add_node(self, '%s-%d' % (name, len(self.basicblocks))) + child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) self.start_new_basicblock(name=name) return child @@ -1096,42 +1302,72 @@ def close_scope(self, outer_scope, parent_req_node, name): self.req_node = parent_req_node self.start_new_basicblock(outer_scope, name) - def require_bit_length(self, bit_length, t='p'): - if t == 'p': + def require_bit_length(self, bit_length, t="p"): + if t == "p": if self.program.prime: - if (bit_length >= self.program.prime.bit_length() - 1): + if bit_length >= self.program.prime.bit_length() - 1: raise CompilerError( - 'required bit length %d too much for %d' % \ - (bit_length, self.program.prime)) - self.req_bit_length[t] = max(bit_length + 1, \ - self.req_bit_length[t]) + "required bit length %d too much for %d" + % (bit_length, self.program.prime) + ) + self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t]) else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) - class Register(object): + @staticmethod + def read_instructions(tapename): + tape = open("Programs/Bytecode/%s.bc" % tapename, "rb") + while tape.peek(): + yield inst_base.ParsedInstruction(tape) + + class _no_truth(object): + __slots__ = [] + + def __bool__(self): + raise CompilerError( + "Cannot derive truth value from register. " + "This is a catch-all error appearing if you try to use a " + "run-time value where the compiler expects a compile-time " + "value, most likely a Python integer. " + "In some cases, you can fix this by using 'compile.py -l'." + ) + + class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. """ - __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ - "size", "vector", "vectorbase", "caller", \ - "can_eliminate", "duplicates"] - maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 + + __slots__ = [ + "reg_type", + "program", + "absolute_i", + "relative_i", + "size", + "vector", + "vectorbase", + "caller", + "can_eliminate", + "duplicates", + "block", + ] + maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): - """ Creates a new register. - reg_type must be one of those defined in RegType. """ - if Compiler.instructions_base.get_global_instruction_type() == 'gf2n': + """Creates a new register. + reg_type must be one of those defined in RegType.""" + if Compiler.instructions_base.get_global_instruction_type() == "gf2n": if reg_type == RegType.ClearModp: reg_type = RegType.ClearGF2N elif reg_type == RegType.SecretModp: reg_type = RegType.SecretGF2N self.reg_type = reg_type self.program = program + self.block = program.active_basicblock if size is None: size = Compiler.instructions_base.get_global_vector_size() if size is not None and size > self.maximum_size: - raise CompilerError('vector too large: %d' % size) + raise CompilerError("vector too large: %d" % size) self.size = size self.vectorbase = self self.relative_i = 0 @@ -1141,7 +1377,7 @@ def __init__(self, reg_type, program, size=None, i=None): self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size else: - self.i = float('inf') + self.i = float("inf") self.vector = [] self.can_eliminate = True self.duplicates = util.set_by_id([self]) @@ -1162,13 +1398,14 @@ def set_size(self, size): if self.size == size: return else: - raise CompilerError('Mismatch of instruction and register size:' - ' %s != %s' % (self.size, size)) + raise CompilerError( + "Mismatch of instruction and register size:" + " %s != %s" % (self.size, size) + ) def set_vectorbase(self, vectorbase): if self.vectorbase is not self: - raise CompilerError('Cannot assign one register' \ - 'to several vectors') + raise CompilerError("Cannot assign one register" "to several vectors") self.relative_i = self.i - vectorbase.i self.vectorbase = vectorbase @@ -1176,7 +1413,7 @@ def _new_by_number(self, i, size=1): return Tape.Register(self.reg_type, self.program, size=size, i=i) def get_vector(self, base=0, size=None): - if size == None: + if size is None: size = self.size if base == 0 and size == self.size: return self @@ -1185,7 +1422,7 @@ def get_vector(self, base=0, size=None): res = self._new_by_number(self.i + base, size=size) res.set_vectorbase(self) self.create_vector_elements() - res.vector = self.vector[base:base+size] + res.vector = self.vector[base : base + size] return res def create_vector_elements(self): @@ -1221,20 +1458,36 @@ def link(self, other): for dup in self.duplicates: dup.duplicates = self.duplicates + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = type(self)(other) + if self.program != other.program: + raise CompilerError( + 'cannot update register with one from another thread') + if other.block in [x.block for x in self.duplicates]: + self.program.start_new_basicblock() + self.link(other) + @property def is_gf2n(self): - return self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.SecretGF2N - + return ( + self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.SecretGF2N + ) + @property def is_clear(self): - return self.reg_type == RegType.ClearModp or \ - self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.ClearInt - - def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") + return ( + self.reg_type == RegType.ClearModp + or self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.ClearInt + ) def __str__(self): return self.reg_type + str(self.i) diff --git a/Compiler/sorting.py b/Compiler/sorting.py new file mode 100644 index 000000000..c8cb87e89 --- /dev/null +++ b/Compiler/sorting.py @@ -0,0 +1,73 @@ +import itertools +from Compiler import types, library, instructions + +def dest_comp(B): + Bt = B.transpose() + St_flat = Bt.get_vector().prefix_sum() + Tt_flat = Bt.get_vector() * St_flat.get_vector() + Tt = types.Matrix(*Bt.sizes, B.value_type) + Tt.assign_vector(Tt_flat) + return sum(Tt) - 1 + +def reveal_sort(k, D, reverse=False): + """ Sort in place according to "perfect" key. The name hints at the fact + that a random order of the keys is revealed. + + :param k: vector or Array of sint containing exactly :math:`0,\dots,n-1` + in any order + :param D: Array or MultiArray to sort + :param reverse: wether :py:obj:`key` is a permutation in forward or + backward order + + """ + assert len(k) == len(D) + library.break_point() + shuffle = types.sint.get_secure_shuffle(len(k)) + k_prime = k.get_vector().secure_permute(shuffle).reveal() + idx = types.Array.create_from(k_prime) + if reverse: + D.assign_vector(D.get_slice_vector(idx)) + library.break_point() + D.secure_permute(shuffle, reverse=True) + else: + D.secure_permute(shuffle) + library.break_point() + v = D.get_vector() + D.assign_slice_vector(idx, v) + library.break_point() + instructions.delshuffle(shuffle) + +def radix_sort(k, D, n_bits=None, signed=True): + """ Sort in place according to key. + + :param k: keys (vector or Array of sint or sfix) + :param D: Array or MultiArray to sort + :param n_bits: number of bits in keys (int) + :param signed: whether keys are signed (bool) + + """ + assert len(k) == len(D) + bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) + if signed and len(bs) > 1: + bs[-1][:] = bs[-1][:].bit_not() + radix_sort_from_matrix(bs, D) + +def radix_sort_from_matrix(bs, D): + n = len(D) + for b in bs: + assert(len(b) == n) + B = types.sint.Matrix(n, 2) + h = types.Array.create_from(types.sint(types.regint.inc(n))) + @library.for_range(len(bs)) + def _(i): + b = bs[i] + B.set_column(0, 1 - b.get_vector()) + B.set_column(1, b.get_vector()) + c = types.Array.create_from(dest_comp(B)) + reveal_sort(c, h, reverse=False) + @library.if_e(i < len(bs) - 1) + def _(): + reveal_sort(h, bs[i + 1], reverse=True) + @library.else_ + def _(): + reveal_sort(h, D, reverse=True) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py new file mode 100644 index 000000000..ae1aa81ca --- /dev/null +++ b/Compiler/sqrt_oram.py @@ -0,0 +1,804 @@ +from __future__ import annotations + +import math +from abc import abstractmethod +from typing import Any, Generic, Type, TypeVar + +from Compiler import library as lib +from Compiler import util +from Compiler.GC.types import cbit, sbit, sbitint, sbits +from Compiler.program import Program +from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint, + regint, sint, sintbit) +from Compiler.oram import demux_array, get_n_threads + +# Adds messages on completion of heavy computation steps +debug = False +# Finer grained trace of steps that the ORAM performs +# + runtime error checks +# Warning: reveals information and makes the computation insecure +trace = False + +n_threads = 16 +n_parallel = 1024 + +# Avoids any memory allocation if set to False +# Setting to False prevents some optimizations but allows for controlling the ORAMs outside of the main tape +allow_memory_allocation = True + + +def get_n_threads(n_loops): + if n_threads is None: + if n_loops > 2048: + return 8 + else: + return None + else: + return n_threads + + +T = TypeVar("T", sint, sbitint) +B = TypeVar("B", sintbit, sbit) + + +class SqrtOram(Generic[T, B]): + """Oblivious RAM using the "Square-Root" algorithm. + + :param MultiArray data: The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). + :param sint value_type: The secret type to use, defaults to sint. + :param int k: Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. + :param int period: Leave at None, this parameter is used to recursively pass down the top-level period. + """ + # TODO: Preferably this is an Array of vectors, but this is currently not supported + # One should regard these structures as Arrays where an entry may hold more + # than one value (which is a nice property to have when using the ORAM in + # practise). + shuffle: MultiArray + stash: MultiArray + # A block has an index and data + # `shuffle` and `stash` store the data, + # `shufflei` and `stashi` store the index + shufflei: Array + stashi: Array + + shuffle_used: Array + position_map: PositionMap + + # The size of the ORAM, i.e. how many elements it stores + n: int + # The period, i.e. how many calls can be made to the ORAM before it needs to be refreshed + T: int + # Keep track of how far we are in the period, and coincidentally how large + # the stash is (each access results in a fake or real block being put on + # the stash) + t: cint + + def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None: + global debug, allow_memory_allocation + + # Correctly initialize the shuffle (memory) depending on the type of data + if isinstance(data, MultiArray): + self.shuffle = data + self.n = len(data) + elif isinstance(data, sint): + self.n = math.ceil(len(data) // entry_length) + if (len(data) % entry_length != 0): + raise Exception('Data incorrectly padded.') + self.shuffle = MultiArray( + (self.n, entry_length), value_type=value_type) + self.shuffle.assign_part_vector(data.get_vector()) + else: + raise Exception("Incorrect format.") + + # Only sint is supported + if value_type != sint and value_type != sbitint: + raise Exception("The value_type must be either sint or sbitint") + + # Set derived constants + self.value_type = value_type + self.bit_type: Type[B] = value_type.bit_type + self.index_size = util.log2(self.n) + 1 # +1 because signed + self.index_type = value_type.get_type(self.index_size) + self.entry_length = entry_length + self.size = self.n + + if debug: + lib.print_ln( + 'Initializing SqrtORAM of size %s at depth %s', self.n, k) + + self.shuffle_used = cint.Array(self.n) + # Random permutation on the data + self.shufflei = Array.create_from( + [self.index_type(i) for i in range(self.n)]) + # Calculate the period if not given + # upon recursion, the period should stay the same ("in sync"), + # therefore it can be passed as a constructor parameter + self.T = int(math.ceil( + math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period + if debug and not period: + lib.print_ln('Period set to %s', self.T) + + # Here we allocate the memory for the permutation + # Note that self.shuffle_the_shuffle mutates this field + # Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading + self.permutation = Array.create_from( + [self.index_type(i) for i in range(self.n)]) + # We allow the caller to postpone the initialization of the shuffle + # This is the most expensive operation, and can be done in a thread (only if you know what you're doing) + # Note that if you do not initialize, the ORAM is insecure + if initialize: + # If the ORAM is not initialized with existing data, we can apply + # a small optimization by forgoing shuffling the shuffle, as all + # entries of the shuffle are equal and empty. + if empty_data: + random_shuffle = sint.get_secure_shuffle(self.n) + self.shufflei.secure_permute(random_shuffle) + self.permutation.assign(self.shufflei[:].inverse_permutation()) + if trace: + lib.print_ln('Calculated inverse permutation') + else: + self.shuffle_the_shuffle() + else: + print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.') + # Initialize position map (recursive oram) + self.position_map = PositionMap.create(self.permutation, k + 1, self.T) + + # Initialize stash + self.stash = MultiArray((self.T, entry_length), value_type=value_type) + self.stashi = Array(self.T, value_type=value_type) + self.t = MemValue(cint(0)) + + # Initialize temp variables needed during the computation + self.found_ = self.bit_type.Array(size=self.T) + self.j = MemValue(cint(0, size=1)) + + # To prevent the compiler from recompiling the same code over and over again, we should use @method_block + # However, @method_block requires allocation (of return address), which is not allowed when not in the main thread + # Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread + SqrtOram.shuffle_the_shuffle = lib.method_block(SqrtOram.shuffle_the_shuffle) if allow_memory_allocation else SqrtOram.shuffle_the_shuffle + SqrtOram.refresh = lib.method_block(SqrtOram.refresh) if allow_memory_allocation else SqrtOram.refresh + SqrtOram.reinitialize = lib.method_block(SqrtOram.reinitialize) if allow_memory_allocation else SqrtOram.reinitialize + + @lib.method_block + def access(self, index: T, write: B, *value: T): + global trace,n_parallel + if trace: + @lib.if_e(write.reveal() == 1) + def _(): + lib.print_ln('Writing to secret index %s', index.reveal()) + + @lib.else_ + def __(): + lib.print_ln('Reading from secret index %s', index.reveal()) + + value = self.value_type(value, size=self.entry_length).get_vector( + 0, size=self.entry_length) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in a bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + # Store the stash item into the result if found + # If the item is not in the stash, the result will simple remain 0 + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + # If we are writing, we need to add the value + self.stash[i] += write * access_here * (value - entry) + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('Found item in stash') + + @lib.else_ + def __(): + lib.print_ln('Did not find item in stash') + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # ...we update the item in the shuffle + self.shuffle[physical_address] += write * \ + found.bit_not() * (value - self.shuffle[physical_address][:]) + # ...and the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + @lib.if_((write * found.bit_not()).reveal()) + def _(): + lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal( + ), self.shuffle[physical_address].reveal(), physical_address) + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + @lib.method_block + def write(self, index: T, *value: T): + global trace, n_parallel + if trace: + lib.print_ln('Writing to secret index %s', index.reveal()) + + if isinstance(value, tuple) or isinstance(value,list): + value = self.value_type(value, size=self.entry_length) + print(value, type(value)) + elif isinstance(value, self.value_type): + value = self.value_type(*value, size=self.entry_length) + print(value, type(value)) + else: + raise Exception("Cannot handle type of value passed") + print(self.entry_length, value, type(value),len(value)) + value = MemValue(value) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in an bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + # We update the stash value + self.stash[i] += access_here * (value - entry) + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('Found item in stash') + + @lib.else_ + def __(): + lib.print_ln('Did not find item in stash') + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # ...we update the item in the shuffle + self.shuffle[physical_address] += found.bit_not() * \ + (value - self.shuffle[physical_address][:]) + # ...and the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + @lib.if_(found.bit_not().reveal()) + def _(): + lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal( + ), self.shuffle[physical_address].reveal(), physical_address) + + lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + @lib.method_block + def read(self, index: T, *value: T): + global debug, trace, n_parallel + if trace: + lib.print_ln('Reading from secret index %s', index.reveal()) + + value = self.value_type(value) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + if debug: + lib.print_ln('Refreshing SqrtORAM') + lib.print_ln('t=%s according to me', self.t) + + self.refresh() + + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in a bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check whether the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + lib.check_point() + + # Store the stash item into the result if found + # If the item is not in the stash, the result will simple remain 0 + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: + # @lib.for_range(self.t) + # def _(i): + # lib.print_ln("stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal()) + + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('Found item in stash (found=%s)', found.reveal()) + + @lib.else_ + def __(): + lib.print_ln('Did not find item in stash (found=%s)', found.reveal()) + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + __getitem__ = read + __setitem__ = write + + def shuffle_the_shuffle(self) -> None: + """Permute the memory using a newly generated permutation and return + the permutation that would generate this particular shuffling. + + This permutation is needed to know how to map logical addresses to + physical addresses, and is used as such by the postition map.""" + + global trace + # Random permutation on n elements + random_shuffle = sint.get_secure_shuffle(self.n) + if trace: + lib.print_ln('Generated shuffle') + # Apply the random permutation + self.shuffle.secure_permute(random_shuffle) + if trace: + lib.print_ln('Shuffled shuffle') + self.shufflei.secure_permute(random_shuffle) + if trace: + lib.print_ln('Shuffled shuffle indexes') + + lib.check_point() + # Calculate the permutation that would have produced the newly produced + # shuffle order. This can be calculated by regarding the logical + # indexes (shufflei) as a permutation and calculating its inverse, + # i.e. find P such that P([1,2,3,...]) = shufflei. + # this is not necessarily equal to the inverse of the above generated + # random_shuffle, as the shuffle may already be out of order (e.g. when + # refreshing). + self.permutation.assign(self.shufflei[:].inverse_permutation()) + # If shufflei does not contain exactly the indices + # [i for i in range(self.n)], + # the underlying waksman network of 'inverse_permutation' will hang. + if trace: + lib.print_ln('Calculated inverse permutation') + + def refresh(self): + """Refresh the ORAM by reinserting the stash back into the shuffle, and + reshuffling the shuffle. + + This must happen on the T'th (period) accesses to the ORAM.""" + + self.j.write(0) + # Shuffle and emtpy the stash, and store elements back into shuffle + + @lib.for_range_opt(self.n) + def _(i): + @lib.if_(self.shuffle_used[i]) + def _(): + self.shuffle[i] = self.stash[self.j] + self.shufflei[i] = self.stashi[self.j] + self.j += 1 + + # Reset the clock + self.t.write(0) + # Reset shuffle_used + self._reset_shuffle_used() + + # Reinitialize position map + self.shuffle_the_shuffle() + # Note that we skip here the step of "packing" the permutation. + # Since the underlying memory of the position map is already aligned in + # this packed structure, we can simply overwrite the memory while + # maintaining the structure. + self.position_map.reinitialize(*self.permutation) + + def reinitialize(self, *data: T): + # Note that this method is only used during refresh, and as such is + # only called with a permutation as data. + + # The logical addresses of some previous permutation are irrelevant and must be reset + self.shufflei.assign([self.index_type(i) for i in range(self.n)]) + # Reset the clock + self.t.write(0) + # Reset shuffle_used + self._reset_shuffle_used() + + # Note that the self.shuffle is actually a MultiArray + # This structure is preserved while overwriting the values using + # assign_vector + self.shuffle.assign_vector(self.value_type( + data, size=self.n * self.entry_length)) + # Note that this updates self.permutation (see constructor for explanation) + self.shuffle_the_shuffle() + self.position_map.reinitialize(*self.permutation) + + def _reset_shuffle_used(self): + global allow_memory_allocation + if allow_memory_allocation: + self.shuffle_used.assign_all(0) + else: + @lib.for_range_opt(self.n) + def _(i): + self.shuffle_used[i] = cint(0) + + +class PositionMap(Generic[T, B]): + PACK_LOG: int = 3 + PACK: int = 1 << PACK_LOG + + n: int # n in the paper + depth: cint # k in the paper + value_type: Type[T] + + def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None: + self.n = n + self.depth = MemValue(cint(k)) + self.value_type = value_type + self.bit_type = value_type.bit_type + self.index_type = self.value_type.get_type(util.log2(n) + 1) # +1 because signed + + @abstractmethod + def get_position(self, logical_address: _secret, fake: B) -> Any: + """Retrieve the block at the given (secret) logical address.""" + global trace + if trace: + print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)', + self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) + + def reinitialize(self, *permutation: T): + """Reinitialize this PositionMap. + + Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`), + we cannot simply call __init__ on self. Instead, we must take care to + reuse and overwrite the same memory. + """ + ... + + @classmethod + def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = sint) -> PositionMap: + """Creates a new PositionMap. This is the method one should call when + needing a new position map. Depending on the size of the given data, it + will either instantiate a RecursivePositionMap or + a LinearPositionMap.""" + n = len(permutation) + + global debug + if n / PositionMap.PACK <= period: + if debug: + lib.print_ln( + 'Initializing LinearPositionMap at depth %s of size %s', k, n) + res = LinearPositionMap(permutation, value_type, k=k) + else: + if debug: + lib.print_ln( + 'Initializing RecursivePositionMap at depth %s of size %s', k, n) + res = RecursivePositionMap(permutation, period, value_type, k=k) + + return res + + +class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): + + def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None: + PositionMap.__init__(self, len(permutation), k=k) + pack = PositionMap.PACK + + # We pack the permutation into a smaller structure, index with a new permutation + packed_size = int(math.ceil(self.n / pack)) + packed_structure = MultiArray( + (packed_size, pack), value_type=value_type) + for i in range(packed_size): + packed_structure[i] = Array.create_from( + permutation[i*pack:(i+1)*pack]) + + SqrtOram.__init__(self, packed_structure, value_type=value_type, + period=period, entry_length=pack, k=self.depth) + + # Initialize random temp variables needed during the computation + self.block_index_demux: Array = self.bit_type.Array(self.T) + self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK) + + @lib.method_block + def get_position(self, logical_address: T, fake: B) -> _clear: + super().get_position(logical_address, fake) + + pack = PositionMap.PACK + pack_log = PositionMap.PACK_LOG + + # The item at logical_address + # will be in block with index h (block.) + # at position l in block.data (block.data) + program = Program.prog + h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)( + logical_address).right_shift(pack_log, program.bit_length))) + l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) + + global trace + if trace: + print_at_depth(self.depth, '-> logical_address=%s: h=%s, l=%s', logical_address.reveal(), h.reveal(), l.reveal()) + # @lib.for_range(self.t) + # def _(i): + # print_at_depth(self.depth, "stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal()) + + # The resulting physical address + p = MemValue(self.index_type(-1)) + found: B = MemValue(self.bit_type(False)) + + # First we try and retrieve the item from the stash at position stash[h][l] + # Since h and l are secret, we do this by scanning the entire stash + + # First we scan the stash for the block we need + self.block_index_demux.assign_all(0) + @lib.for_range_opt_multithread(get_n_threads(self.T), self.T) + def _(i): + self.block_index_demux[i] = ( self.stashi[i] == h) & self.bit_type(i < self.t) + # We can determine if the 'index' is in the stash by checking the + # block_index_demux array + found = sum(self.block_index_demux) + # Once a block is found, we use the following condition to pick the correct item from that block + demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux) + + # Finally we use the conditions to conditionally write p + @lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type]) + def p_(i): + # We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum + # Therefore we include the check (i < self.t) + return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i // pack< self.t) + p.write(p_()) + + if trace: + @lib.if_e(found.reveal() == 0) + def _(): print_at_depth(self.depth, 'Retrieve shuffle[%s]:', h.reveal()) + @lib.else_ + def __(): + print_at_depth(self.depth, 'Retrieve dummy element from shuffle:') + + # Then we try and retrieve the item from the shuffle (the actual memory) + # Depending on whether we found the item in the stash, we either + # block 'h' in which 'index' resides, or a random block from the shuffle + p_prime = self.position_map.get_position(h, found) + self.shuffle_used[p_prime] = cbit(True) + + # The block retrieved from the shuffle + block_p_prime: Array = self.shuffle[p_prime] + + if trace: + @lib.if_e(found.reveal() == 0) + def _(): + print_at_depth(self.depth, 'Retrieved position from shuffle[%s]=(%s: %s)', + p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal()) + + @lib.else_ + def __(): + print_at_depth(self.depth, 'Retrieved dummy position from shuffle[%s]=(%s: %s)', + p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal()) + + # We add the retrieved block from the shuffle to the stash + self.stash[self.t].assign(block_p_prime[:]) + self.stashi[self.t] = self.shufflei[p_prime] + # Increase t + self.t += 1 + + # if found or not fake + condition: B = self.bit_type(fake.bit_or(found.bit_not())) + # Retrieve l'th item from block + # l is secret, so we must use linear scan + hit = Array.create_from((regint.inc(pack) == l.expand_to_vector( + pack)) & condition.expand_to_vector(pack)) + + @lib.for_range_opt(pack) + def _(i): + p.write((hit[i]).if_else(block_p_prime[i], p)) + + return p.reveal() + + def reinitialize(self, *permutation: T): + SqrtOram.reinitialize(self, *permutation) + + +class LinearPositionMap(PositionMap): + physical: Array + used: Array + + def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None: + PositionMap.__init__(self, len(data), value_type, k=k) + self.physical = data + self.used = self.bit_type.Array(self.n) + + # Initialize random temp variables needed during the computation + self.physical_demux: Array = self.bit_type.Array(self.n) + + @lib.method_block + def get_position(self, logical_address: T, fake: B) -> _clear: + """ + This method corresponds to GetPosBase in the paper. + """ + super().get_position(logical_address, fake) + + global trace + if trace: + @lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal()) + def _(): + lib.runtime_error( + 'logical_address must lie between 0 and self.n - 1') + + fake = MemValue(self.bit_type(fake)) + logical_address = MemValue(logical_address) + + p: MemValue = MemValue(self.index_type(-1)) + done: B = self.bit_type(False) + + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + self.physical_demux.assign_all(0) + + @lib.for_range_opt_multithread(get_n_threads(self.n), self.n) + def condition_i(i): + self.physical_demux[i] = \ + (self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \ + | (fake & self.used[i].bit_not()) + + # In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False) + # We only need once, so we pick the first one we find + @lib.for_range_opt(self.n) + def _(i): + self.physical_demux[i] &= done.bit_not() + done.update(done | self.physical_demux[i]) + + # Retrieve the value from the physical memory obliviously + @lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type]) + def calc_p(i): + return self.physical[i] * self.physical_demux[i] + p.write(calc_p()) + + # Update self.used + self.used.assign(self.used[:] | self.physical_demux[:]) + + if trace: + @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) + def _(): + lib.runtime_error( + '%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) + + return p.reveal() + + def reinitialize(self, *data: T): + self.physical.assign_vector(data) + + global allow_memory_allocation + if allow_memory_allocation: + self.used.assign_all(False) + else: + @lib.for_range_opt(self.n) + def _(i): + self.used[i] = self.bit_type(0) + +def print_at_depth(depth: cint, message: str, *kwargs): + lib.print_str('%s', depth) + @lib.for_range(depth) + def _(i): + lib.print_char(' ') + lib.print_char(' ') + lib.print_ln(message, *kwargs) diff --git a/Compiler/types.py b/Compiler/types.py index 917eba752..12d6fb722 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,8 +127,14 @@ def vectorized_operation(self, *args, **kwargs): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise CompilerError('Different vector sizes of operands: %d/%d' - % (self.size, args[0].size)) + if min(args[0].size, self.size) == 1: + size = max(args[0].size, self.size) + self = self.expand_to_vector(size) + args = list(args) + args[0] = args[0].expand_to_vector(size) + else: + raise VectorMismatch('Different vector sizes of operands: %d/%d' + % (self.size, args[0].size)) set_global_vector_size(self.size) try: res = operation(self, *args, **kwargs) @@ -160,7 +166,7 @@ def vectorized_function(cls, *args, **kwargs): size = None if 'size' in kwargs: size = kwargs.pop('size') - if size: + if size is not None: set_global_vector_size(size) try: res = function(cls, *args, **kwargs) @@ -181,7 +187,7 @@ def vectorized_init(*args, **kwargs): if 'size' in kwargs and kwargs['size'] is not None \ and kwargs['size'] != size: raise CompilerError('Mismatch in vector size') - if 'size' in kwargs and kwargs['size']: + if 'size' in kwargs and kwargs['size'] is not None: size = kwargs['size'] if size is not None: set_global_vector_size(size) @@ -214,6 +220,14 @@ def read_mem_operation(self, *args, **kwargs): copy_doc(read_mem_operation, operation) return read_mem_operation +def type_comp(operation): + def type_check(self, other, *args, **kwargs): + if not isinstance(other, (type(self), int, regint, self.clear_type)): + return NotImplemented + return operation(self, other, *args, **kwargs) + copy_doc(type_check, operation) + return type_check + def inputmixed(*args): # helper to cover both cases if isinstance(args[-1], int): @@ -221,7 +235,7 @@ def inputmixed(*args): else: instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),))) -class _number(object): +class _number(Tape._no_truth): """ Number functionality. """ def square(self): @@ -246,7 +260,14 @@ def __mul__(self, other): elif is_one(other): return self else: - return self.mul(other) + try: + return self.mul(other) + except VectorMismatch: + if type(self) != type(other) and 1 in (self.size, other.size): + # try reverse multiplication + return NotImplemented + else: + raise __radd__ = __add__ __rmul__ = __mul__ @@ -266,7 +287,7 @@ def __pow__(self, exp): if i == '1': res *= self return res - else: + elif isinstance(exp, _int): bits = exp.bit_decompose() powers = [self] while len(powers) < len(bits): @@ -274,6 +295,9 @@ def __pow__(self, exp): multiplicands = [b.if_else(p, 1) for b, p in zip(bits, powers)] res = util.tree_reduce(operator.mul, multiplicands) return res + else: + from .mpc_math import pow_fx + return pow_fx(self, exp) def mul_no_reduce(self, other, res_params=None): return self * other @@ -320,7 +344,14 @@ def __abs__(self): def popcnt_bits(bits): return sum(bits) -class _int(object): + def zero_if_not(self, condition): + return condition * self + + def iadd(self, other): + """ Addition assignment. This uses :py:func:`update` internally. """ + self.update(self + other) + +class _int(Tape._no_truth): """ Integer functionality. """ @staticmethod @@ -408,7 +439,7 @@ def half_adder(self, other): def long_one(): return 1 -class _bit(object): +class _bit(Tape._no_truth): """ Binary functionality. """ def bit_xor(self, other): @@ -448,6 +479,10 @@ def carry_out(self, a, b): s = a ^ b return a ^ (s & (self ^ a)) + def cond_swap(self, a, b): + prod = self * (a ^ b) + return a ^ prod, b ^ prod + class _gf2n(_bit): """ :math:`\mathrm{GF}(2^n)` functionality. """ @@ -474,7 +509,7 @@ def bit_xor(self, other): def bit_not(self): return self ^ 1 -class _structure(object): +class _structure(Tape._no_truth): """ Interface for type-dependent container types. """ MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value))) @@ -509,6 +544,8 @@ def Tensor(cls, shape): """ if len(shape) == 1: return Array(shape[0], cls) + elif len(shape) == 2: + return Matrix(*shape, cls) else: return MultiArray(shape, cls) @@ -549,7 +586,8 @@ def input_tensor_from_client(cls, client_id, shape): return res @classmethod - def input_tensor_via(cls, player, content): + def input_tensor_via(cls, player, content=None, shape=None, binary=True, + one_hot=False): """ Input tensor-like data via a player. This overwrites the input file for the relevant player. The following returns an @@ -558,40 +596,77 @@ def input_tensor_via(cls, player, content): M = [[1, 2], [3, 4]] sint.input_tensor_via(0, M) - Make sure to copy ``Player-Data/Input-P-0`` if running + Make sure to copy ``Player-Data/Input-P-0`` or + ``Player-Data/Input-Binary-P-0`` if running on another host. + :param player: player to input via (int) + :param content: nested Python list or numpy array (binary mode only) or + left out if not available + :param shape: shape if content not given + :param binary: binary mode (bool) + :param one_hot: one-hot encoding (bool) + """ if program.curr_tape != program.tapes[0]: raise CompilerError('only available in main thread') - shape = [] - tmp = content - while True: - try: - shape.append(len(tmp)) - tmp = tmp[0] - except: - break - if not program.input_files.get(player, None): - program.input_files[player] = open( - 'Player-Data/Input-P%d-0' % player, 'w') - f = program.input_files[player] - def traverse(content, level): - assert len(content) == shape[level] - if level == len(shape) - 1: - for x in content: - f.write(' ') - f.write(str(x)) + if content is not None: + requested_shape = shape + if binary: + import numpy + content = numpy.array(content) + if issubclass(cls, _fix): + min_k = \ + math.ceil(math.log(abs(content).max(), 2)) + cls.f + 1 + if cls.k < min_k: + raise CompilerError( + "data outside fixed-point range, " + "use 'sfix.set_precision(%d, %d)'" % (cls.f, min_k)) + if binary == 2: + t = numpy.double + else: + t = numpy.single + else: + t = numpy.int64 + if one_hot: + content = numpy.eye(content.max() + 1)[content] + content = content.astype(t) + f = program.get_binary_input_file(player) + f.write(content.tobytes()) + f.flush() + shape = content.shape else: - for x in content: - traverse(x, level + 1) - traverse(content, 0) - f.write('\n') + shape = [] + tmp = content + while True: + try: + shape.append(len(tmp)) + tmp = tmp[0] + except: + break + if not program.input_files.get(player, None): + program.input_files[player] = open( + 'Player-Data/Input-P%d-0' % player, 'w') + f = program.input_files[player] + def traverse(content, level): + assert len(content) == shape[level] + if level == len(shape) - 1: + for x in content: + f.write(' ') + f.write(str(x)) + else: + for x in content: + traverse(x, level + 1) + traverse(content, 0) + f.write('\n') + if requested_shape is not None and \ + list(shape) != list(requested_shape): + raise CompilerError('content contradicts shape') res = cls.Tensor(shape) - res.input_from(player) + res.input_from(player, binary=binary) return res -class _vec(object): +class _vec(Tape._no_truth): def link(self, other): assert len(self.v) == len(other.v) for x, y in zip(self.v, other.v): @@ -726,10 +801,17 @@ def expand_to_vector(self, size=None): assert self.size == 1 res = type(self)(size=size) for i in range(size): - movs(res[i], self) + self.mov(res[i], self) return res -class _clear(_register): +class _arithmetic_register(_register): + """ Arithmetic circuit type. """ + def __init__(self, *args, **kwargs): + if program.options.garbled: + raise CompilerError('functionality only available in arithmetic circuits') + super(_arithmetic_register, self).__init__(*args, **kwargs) + +class _clear(_arithmetic_register): """ Clear domain-dependent type. """ __slots__ = [] mov = staticmethod(movc) @@ -1010,9 +1092,11 @@ def less_than(self, other, bit_length): if bit_length <= 64: return regint(self) < regint(other) else: + sint.require_bit_length(bit_length + 1) diff = self - other - shifted = diff >> (bit_length - 1) - res = regint(shifted & 1) + diff += 1 << bit_length + shifted = diff >> bit_length + res = 1 - regint(shifted & 1) return res def __lt__(self, other): @@ -1063,6 +1147,8 @@ def __eq__(self, other): def __ne__(self, other): return 1 - (self == other) + equal = lambda self, other, *args, **kwargs: self.__eq__(other) + def __lshift__(self, other): """ Clear left shift. @@ -1142,12 +1228,14 @@ def bit_decompose(self, bit_length=None): bit_length = bit_length or program.bit_length return floatingpoint.bits(self, bit_length) + @vectorize def legendre(self): """ Clear Legendre symbol computation. """ res = cint() legendrec(res, self) return res + @vectorize def digest(self, num_bytes): """ Clear hashing (libsodium default). """ res = cint() @@ -1157,11 +1245,11 @@ def digest(self, num_bytes): def print_if(self, string): """ Output if value is non-zero. - :param string: Python string """ + :param string: bytearray """ cond_print_str(self, string) def output_if(self, cond): - cond_print_plain(self.conv(cond), self, cint(0)) + cond_print_plain(self.conv(cond), self, cint(0, size=self.size)) class cgf2n(_clear, _gf2n): @@ -1316,14 +1404,14 @@ def store_in_mem(self, address): @vectorized_classmethod def pop(cls): - """ Pop from stack. """ + """ Pop from stack. Made obsolete by :py:func:`update`. """ res = cls() popint(res) return res @vectorized_classmethod def push(cls, value): - """ Push to stack. + """ Push to stack. Made obsolete by :py:func:`update`. :param value: any convertible type """ pushint(cls.conv(value)) @@ -1630,7 +1718,7 @@ def output_if(self, cond): def _condition(self): if program.options.binary: - from GC.types import cbits + from .GC.types import cbits return cbits.get_type(64)(self) else: return cint(self) @@ -1643,9 +1731,11 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') intoutput(player, self) -class localint(object): +class localint(Tape._no_truth): """ Local integer that must prevented from leaking into the secure computation. Uses regint internally. @@ -1668,7 +1758,14 @@ def output(self): __eq__ = lambda self, other: localint(self._v == other) __ne__ = lambda self, other: localint(self._v != other) -class personal(object): +class personal(Tape._no_truth): + """ Value known to one player. Supports operations with public + values and personal values known to the same player. Can be used + with :py:func:`~Compiler.library.print_ln_to`. + + :param player: player (int) + :param value: cleartext value (cint, cfix, cfloat) or array thereof + """ def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -1678,10 +1775,63 @@ def __init__(self, player, value): self.player = player self._v = value + @classmethod + def read_int(cls, player): + """ Read integer from + ``Player-Data/Input-Binary-P-`` only on + party :py:obj:`player`. + + :param player: player (int) + :return: personal cint + + """ + tmp = cint() + fixinput(player, tmp, 0, 0) + return cls(player, tmp) + + @classmethod + def read_fix(cls, player, f, k, precision): + """ Read fixed-point value from + ``Player-Data/Input-Binary-P-`` only on + party :py:obj:`player`. + + :param player: player (int) + :param f: fixed-point precision (int) + :param k: fixed-point length (int) + :param precision: input precision (1: single, 2: double) + :return: personal cfix + + """ + assert precision in (1, 2) + tmp = cint() + fixinput(player, tmp, f, precision) + return cls(player, cfix._new(tmp, f=f, k=k)) + def binary_output(self): + """ Write binary output to + ``Player-Data/Binary-Output-P-`` if + supported by underlying type. Player must be known at compile time.""" self._v.binary_output(self.player) - def bit_decompose(self, length): + def reveal_to(self, player): + """ Pass personal value to another player. """ + if isinstance(self._v, Array): + source = self._v[:] + else: + source = self._v + source = cint.conv(source) + res = cint(size=source.size) + sendpersonal(source.size, player, res, self.player, source) + if isinstance(self._v, Array): + res = Array.create_from(res) + return personal(player, res) + + def bit_decompose(self, length=None): + """ Bit decomposition. + + :param length: number of bits + + """ return [personal(self.player, x) for x in self._v.bit_decompose(length)] def _san(self, other): @@ -1692,6 +1842,12 @@ def _san(self, other): def _div_san(self): return self._v.conv((library.get_player_id() == self.player)._v).if_else(self._v, 1) + def __setitem__(self, index, value): + self._san(value) + self._v[index] = value + + __getitem__ = lambda self, index: personal(self.player, self._v[index]) + __add__ = lambda self, other: personal(self.player, self._san(other) + other) __sub__ = lambda self, other: personal(self.player, self._san(other) - other) __mul__ = lambda self, other: personal(self.player, self._san(other) * other) @@ -1776,7 +1932,7 @@ def bit_decompose(self, bit_length): res += x.bit_decompose(64) return res[:bit_length] -class _secret(_register, _secret_structure): +class _secret(_arithmetic_register, _secret_structure): __slots__ = [] mov = staticmethod(set_instruction_type(movs)) @@ -1846,8 +2002,13 @@ def get_random_inverse(cls): @vectorized_classmethod @set_instruction_type def get_random_input_mask_for(cls, player): - res = cls() - inputmask(res, player) + """ Secret random input mask according to security model. + + :return: mask (sint), mask (personal cint) + :param size: vector size (int, default 1) + """ + res = cls(), personal(player, cls.clear_type()) + inputmask(res[0], res[1]._v, player) return res @classmethod @@ -1890,6 +2051,11 @@ def matrix_mul(cls, A, B, n, res_params=None): matmuls(res, A, B, n_rows, n, n_cols) return res + @staticmethod + def _new(self): + # mirror sfix + return self + @no_doc def __init__(self, reg_type, val=None, size=None): if isinstance(val, self.clear_type): @@ -1922,14 +2088,7 @@ def load_other(self, val): r = self.get_dabit() movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit())) elif isinstance(val, sbitvec): - assert(sum(x.n for x in val.v) == self.size) - for val_part, base in zip(val, range(0, self.size, 64)): - left = min(64, self.size - base) - r = self.get_dabit(size=left) - v = regint(size=left) - bitdecint_class(regint((r[1] ^ val_part).reveal()), *v) - part = r[0].bit_xor(v) - vmovs(left, self.get_vector(base, left), part) + movs(self, sint.bit_compose(val)) else: self.load_clear(self.clear_type(val)) @@ -1939,6 +2098,8 @@ def bit_compose(cls, bits): :param bits: iterable of any type convertible to sint """ from Compiler.GC.types import sbits, sbitintvec + if isinstance(bits, sbits): + bits = bits.bit_decompose() bits = list(bits) if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits): if program.use_edabit(): @@ -2007,9 +2168,11 @@ def mul(self, other): size or one size 1 for a value-vector multiplication. :param other: any compatible type """ - if isinstance(other, _secret) and (1 in (self.size, other.size)) \ + if isinstance(other, _register) and (1 in (self.size, other.size)) \ and (self.size, other.size) != (1, 1): x, y = (other, self) if self.size < other.size else (self, other) + if not isinstance(other, _secret): + return y.expand_to_vector(x.size) * x res = type(self)(size=x.size) mulrs(res, x, y) return res @@ -2025,12 +2188,15 @@ def __rsub__(self, other): return self.secret_op(other, subs, submr, subsfi, True) __rsub__.__doc__ = __sub__.__doc__ - @vectorize def __truediv__(self, other): """ Secret field division. :param other: any compatible type """ - return self * (self.clear_type(1) / other) + try: + one = self.clear_type(1, size=other.size) + except AttributeError: + one = self.clear_type(1) + return self * (one / other) @vectorize def __rtruediv__(self, other): @@ -2049,30 +2215,49 @@ def square(self): else: return self * self + @set_instruction_type + def secure_shuffle(self, unit_size=1): + res = type(self)(size=self.size) + secshuffle(res, self, unit_size) + return res + @set_instruction_type @vectorize - def reveal(self): + def reveal(self, check=True): """ Reveal secret value publicly. :rtype: relevant clear type """ res = self.clear_type() - asm_open(res, self) + asm_open(check, res, self) return res @set_instruction_type def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result written to ``Player-Data/Private-Output-P`` :param player: int - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` + """ + mask = self.get_random_input_mask_for(player) + masked = self + mask[0] + res = personal(player, masked.reveal() - mask[1]) + return res + + @set_instruction_type + @vectorize + def raw_right_shift(self, length): + """ Local right shift in supported protocols. + In integer-like protocols, the output is potentially off by one. + + :param length: number of bits """ - masked = self.__class__() - res = personal(player, self.clear_type()) - startprivateoutput(masked, self, player) - stopprivateoutput(res._v, masked.reveal(), player) + res = type(self)() + shrsi(res, self, length) return res + def raw_mod2m(self, m): + return self - (self.raw_right_shift(m) << m) + class sint(_secret, _int): """ @@ -2091,9 +2276,7 @@ class sint(_secret, _int): signed integer in a restricted range, see below. The same holds for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and exponentation (``**``). Modulo only works if the right-hand - operator is a compile-time power of two, and exponentiation only - works if the base is two or if the exponent is a compile-time - integer. + operator is a compile-time power of two. Most non-linear operations require compile-time parameters for bit length and statistical security. They default to the global @@ -2109,9 +2292,17 @@ class sint(_secret, _int): the bit length. :param val: initialization (sint/cint/regint/int/cgf2n or list - thereof or sbits/sbitvec/sfix) + thereof, sbits/sbitvec/sfix, or :py:class:`personal`) :param size: vector size (int), defaults to 1 or size of list + When converting :py:class:`~Compiler.GC.types.sbits`, the result is a + vector of bits, and when converting + :py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values + with bit length equal the length of the input. + + Initializing from a :py:class:`personal` value implies the + relevant party inputting their value securely. + """ __slots__ = [] instruction_type = 'modp' @@ -2166,14 +2357,17 @@ def get_random(cls): return res @vectorized_classmethod - def get_input_from(cls, player): + def get_input_from(cls, player, binary=False): """ Secret input. :param player: public (regint/cint/int) :param size: vector size (int, default 1) """ - res = cls() - inputmixed('int', res, player) + if binary: + return cls(personal.read_int(player)) + else: + res = cls() + inputmixed('int', res, player) return res @vectorized_classmethod @@ -2220,11 +2414,13 @@ def get_raw_input_from(cls, player): @vectorized_classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of values input by a client. + This uses the triple-based input protocol introduced by + `Damgård et al. `_ :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) - + :returns: list of sint """ # send shares of a triple to client triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) @@ -2256,6 +2452,9 @@ def reveal_to_clients(cls, clients, values): n_clients = clients.length else: n_clients = len(clients) + set_global_vector_size(1) + clients = Array.create_from(regint.conv(clients)) + reset_global_vector_size() @library.for_range(n_clients) def loop_body(i): @@ -2278,6 +2477,17 @@ def read_from_socket(cls, client_id, n=1): else: return res + @vectorized_classmethod + def write_to_socket(cls, client_id, values, + message_type=ClientMessageType.NoType): + """ Send a list of shares and MAC shares to a client socket. + + :param client_id: regint + :param values: list of sint + + """ + writesockets(client_id, message_type, values[0].size, *values) + @vectorize def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType): """ Send only share to socket """ @@ -2308,16 +2518,20 @@ def read_from_file(cls, start, n_items): return stop, shares @staticmethod - def write_to_file(shares): + def write_to_file(shares, position=None): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). - :param: shares (list or iterable of sint) + :param shares: (list or iterable of sint) + :param position: start position (int/regint/cint), + defaults to end of file """ for share in shares: assert isinstance(share, sint) assert share.size == 1 - writesharestofile(*shares) + if position is None: + position = -1 + writesharestofile(regint.conv(position), *shares) @vectorized_classmethod def load_mem(cls, address, mem_type=None): @@ -2339,13 +2553,17 @@ def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None): @vectorize_init def __init__(self, val=None, size=None): + from .GC.types import sbitvec if isinstance(val, personal): size = val._v.size super(sint, self).__init__('s', size=size) inputpersonal(size, val.player, self, self.clear_type.conv(val._v)) elif isinstance(val, _fix): super(sint, self).__init__('s', size=val.v.size) - self.load_other(val.v.round(val.k, val.f)) + self.load_other(val.v.round(val.k, val.f, + nearest=val.round_nearest)) + elif isinstance(val, sbitvec): + super(sint, self).__init__('s', val=val, size=val[0].n) else: super(sint, self).__init__('s', val=val, size=size) @@ -2360,6 +2578,7 @@ def __abs__(self): return (self >= 0).if_else(self, -self) @read_mem_value + @type_comp @vectorize def __lt__(self, other, bit_length=None, security=None): """ Secret comparison (signed). @@ -2374,6 +2593,7 @@ def __lt__(self, other, bit_length=None, security=None): return res @read_mem_value + @type_comp @vectorize def __gt__(self, other, bit_length=None, security=None): res = sintbit() @@ -2382,18 +2602,26 @@ def __gt__(self, other, bit_length=None, security=None): security or program.security) return res + @read_mem_value + @type_comp def __le__(self, other, bit_length=None, security=None): return 1 - self.greater_than(other, bit_length, security) + @read_mem_value + @type_comp def __ge__(self, other, bit_length=None, security=None): return 1 - self.less_than(other, bit_length, security) @read_mem_value + @type_comp @vectorize def __eq__(self, other, bit_length=None, security=None): - return floatingpoint.EQZ(self - other, bit_length or program.bit_length, - security or program.security) + return sintbit.conv( + floatingpoint.EQZ(self - other, bit_length or program.bit_length, + security or program.security)) + @read_mem_value + @type_comp def __ne__(self, other, bit_length=None, security=None): return 1 - self.equal(other, bit_length, security) @@ -2552,7 +2780,8 @@ def Norm(self, k, f, kappa=None, simplex_flag=False): @vectorize def int_div(self, other, bit_length=None, security=None): - """ Secret integer division. + """ Secret integer division. Note that the domain bit length + needs to be about four times the bit length. :param other: sint :param bit_length: bit length of input (default: global bit length) @@ -2564,12 +2793,22 @@ def int_div(self, other, bit_length=None, security=None): comparison.Trunc(res, tmp, 2 * k, k, kappa, True) return res + @vectorize + def int_mod(self, other, bit_length=None): + """ Secret integer modulo. Note that the domain bit length + needs to be about four times the bit length. + + :param other: sint + :param bit_length: bit length of input (default: global bit length) + """ + return self - other * self.int_div(other, bit_length=bit_length) + def trunc_zeros(self, n_zeros, bit_length=None, signed=True): bit_length = bit_length or program.bit_length return comparison.TruncZeros(self, bit_length, n_zeros, signed) @staticmethod - def two_power(n): + def two_power(n, size=None): return floatingpoint.two_power(n) def split_to_n_summands(self, length, n): @@ -2587,33 +2826,23 @@ def split_to_two_summands(self, length, get_carry=False): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize - def raw_right_shift(self, length): - res = sint() - shrsi(res, self, length) - return res - - def raw_mod2m(self, m): - return self - (self.raw_right_shift(m) << m) - - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result potentially written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. - :param player: public integer (int/regint/cint): - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` """ - if not util.is_constant(player) or self.size > 1: - secret_mask = sint() - player_mask = cint() - inputmaskreg(secret_mask, player_mask, regint.conv(player)) + if not util.is_constant(player): + secret_mask = sint(size=self.size) + player_mask = cint(size=self.size) + inputmaskreg(secret_mask, player_mask, + regint.conv(player).expand_to_vector(self.size)) return personal(player, - (self + secret_mask).reveal() - player_mask) + (self + secret_mask).reveal(False) - player_mask) else: - return super(sint, self).reveal_to(player) + res = personal(player, self.clear_type(size=self.size)) + privateoutput(self.size, player, res._v, self) + return res def private_division(self, divisor, active=True, dividend_length=None, divisor_length=None): @@ -2672,6 +2901,41 @@ def private_division(self, divisor, active=True, dividend_length=None, return w + @staticmethod + def get_secure_shuffle(n): + res = regint() + gensecshuffle(res, n) + return res + + def secure_permute(self, shuffle, unit_size=1, reverse=False): + res = sint(size=self.size) + applyshuffle(res, self, unit_size, shuffle, reverse) + return res + + def inverse_permutation(self): + if program.use_invperm(): + # If enabled, we use the low-level INVPERM instruction. + # This instruction has only been implemented for a semi-honest two-party environement. + res = sint(size=self.size) + inverse_permutation(res, self) + else: + shuffle = sint.get_secure_shuffle(len(self)) + shuffled = self.secure_permute(shuffle).reveal() + idx = Array.create_from(shuffled) + res = Array.create_from(sint(regint.inc(len(self)))) + res.secure_permute(shuffle, reverse=False) + res.assign_slice_vector(idx, res.get_vector()) + library.break_point() + res = res.get_vector() + return res + + @vectorize + def prefix_sum(self): + """ Prefix sum. """ + res = sint() + prefixsums(res, self) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -2720,7 +2984,9 @@ def __xor__(self, other): elif util.is_zero(other): return self elif util.is_one(other): - return 1 + res = sintbit() + submr(res, cint(1), self) + return res else: return NotImplemented @@ -2733,6 +2999,10 @@ def __rsub__(self, other): else: return super(sintbit, self).__rsub__(other) + __rand__ = __and__ + __rxor__ = __xor__ + __ror__ = __or__ + class sgf2n(_secret, _gf2n): """ Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A @@ -2750,6 +3020,7 @@ class sgf2n(_secret, _gf2n): instruction_type = 'gf2n' clear_type = cgf2n reg_type = 'sg' + long_one = staticmethod(lambda: 1) @classmethod def get_type(cls, length): @@ -2895,10 +3166,11 @@ def bit_decompose_embedding(self): sint.bit_type = sintbit sgf2n.bit_type = sgf2n -class _bitint(object): +class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False + comp_result = staticmethod(lambda x: x) @staticmethod def half_adder(a, b): @@ -3118,12 +3390,16 @@ def wallace_reduction(cls, a, b, c, get_carry=True): del carries[-1] return sums, carries + def expand(self, other): + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + return a, b + def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) from util import bit_not, bit_and, bit_xor + a, b = self.expand(other) n = 1 for x in (a + b): try: @@ -3170,8 +3446,7 @@ def prep_comparison(a, b): a[-1], b[-1] = b[-1], a[-1] def comparison(self, other, const_rounds=False, index=None): - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) + a, b = self.expand(other) self.prep_comparison(a, b) if const_rounds: return self.get_highest_different_bits(a, b, index) @@ -3181,30 +3456,33 @@ def comparison(self, other, const_rounds=False, index=None): def __lt__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 0) + res = util.if_else(not_equal, x, 0) else: - return self.comparison(other, True, 1) + res = self.comparison(other, True, 1) + return self.comp_result(res) def __le__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 1) + res = util.if_else(not_equal, x, x.long_one()) else: - return 1 - self.comparison(other, True, 0) + res = self.comparison(other, True, 0).bit_not() + return self.comp_result(res) def __ge__(self, other): - return 1 - (self < other) + return (self < other).bit_not() def __gt__(self, other): - return 1 - (self <= other) + return (self <= other).bit_not() def __eq__(self, other, bit_length=None, security=None): diff = self ^ other - diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]] - return floatingpoint.KMul(diff_bits) + diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] + return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), + diff_bits)) def __ne__(self, other): - return 1 - (self == other) + return (self == other).bit_not() equal = __eq__ @@ -3434,9 +3712,8 @@ class cfix(_number, _structure): scalars = (int, float, regint, cint) @classmethod def set_precision(cls, f, k = None): - """ Set the precision of the integer representation. Note that some - operations are undefined when the precision of :py:class:`sfix` and - :py:class:`cfix` differs. The initial defaults are chosen to + """ Set the precision of the integer representation. + The initial defaults are chosen to allow the best optimization of probabilistic truncation in computation modulo 2^64 (2*k < 64). Generally, 2*k must be at most the integer length for rings and at most m-s-1 for @@ -3494,6 +3771,10 @@ def cfix_to_cint(fix_val): def malloc(size, creator_tape=None): return program.malloc(size, cint, creator_tape=creator_tape) + @classmethod + def free(cls, addr): + return cint.free(addr) + @staticmethod def n_elements(): return 1 @@ -3506,6 +3787,7 @@ def from_int(cls, other): @classmethod def _new(cls, other, k=None, f=None): + assert not isinstance(other, (list, tuple)) res = cls(k=k, f=f) res.v = cint.conv(other) return res @@ -3552,8 +3834,13 @@ def __len__(self): return len(self.v) def __getitem__(self, index): + if isinstance(index, slice): + return [self._new(x, k=self.k, f=self.f) for x in self.v[index]] return self._new(self.v[index], k=self.k, f=self.f) + def get_vector(self): + return self + @vectorize def load_int(self, v): self.v = cint(v) * (2 ** self.f) @@ -3582,18 +3869,28 @@ def size(self): def sizeof(self): return self.size * 4 + @read_mem_value + def parse_type(self, other): + res = parse_type(other, f=self.f, k=self.k) + # check attributes if available + try: + assert res.k == self.k + assert res.f == self.f + except AttributeError: + pass + return res + @vectorize def add(self, other): """ Clear fixed-point addition. :param other: cfix/cint/regint/int """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): - return cfix._new(self.v + other.v) + return cfix._new(self.v + other.v, k=self.k, f=self.f) else: return NotImplemented - @vectorize def mul(self, other): """ Clear fixed-point multiplication. @@ -3602,13 +3899,13 @@ def mul(self, other): return sfix._new(self.v * other, k=self.k, f=self.f) if isinstance(other, (int, regint, cint)): return cfix._new(self.v * cint(other), k=self.k, f=self.f) - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): assert self.f == other.f sgn = cint(1 - 2 * ((self < 0) ^ (other < 0))) absolute = self.v * other.v * sgn val = sgn * (absolute >> self.f) - return cfix._new(val) + return cfix._new(val, k=self.k, f=self.f) elif isinstance(other, sfix): return NotImplemented else: @@ -3625,11 +3922,11 @@ def __sub__(self, other): """ Clear fixed-point subtraction. :param other: cfix/cint/regint/int """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): - return cfix._new(self.v - other.v) + return cfix._new(self.v - other.v, k=self.k, f=self.f) elif isinstance(other, sfix): - return sfix._new(self.v - other.v) + return sfix._new(self.v - other.v, k=self.k, f=self.f) else: raise NotImplementedError @@ -3637,7 +3934,7 @@ def __sub__(self, other): def __neg__(self): """ Clear fixed-point negation. """ # cfix type always has .v - return cfix._new(-self.v) + return cfix._new(-self.v, f=self.f, k=self.k) def __rsub__(self, other): return -self + other @@ -3650,7 +3947,7 @@ def __eq__(self, other): :param other: cfix/cint/regint/int :return: 0/1 :rtype: regint """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): return self.v == other.v elif isinstance(other, sfix): @@ -3661,7 +3958,7 @@ def __eq__(self, other): @vectorize def __lt__(self, other): """ Clear fixed-point comparison. """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): assert self.k == other.k return self.v.less_than(other.v, self.k) @@ -3675,7 +3972,7 @@ def __lt__(self, other): @vectorize def __le__(self, other): """ Clear fixed-point comparison. """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): return 1 - (self > other) elif isinstance(other, sfix): @@ -3686,7 +3983,7 @@ def __le__(self, other): @vectorize def __gt__(self, other): """ Clear fixed-point comparison. """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): return other.__lt__(self) elif isinstance(other, sfix): @@ -3697,7 +3994,7 @@ def __gt__(self, other): @vectorize def __ge__(self, other): """ Clear fixed-point comparison. """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): return 1 - (self < other) elif isinstance(other, sfix): @@ -3708,7 +4005,7 @@ def __ge__(self, other): @vectorize def __ne__(self, other): """ Clear fixed-point comparison. """ - other = parse_type(other) + other = self.parse_type(other) if isinstance(other, cfix): return self.v != other.v elif isinstance(other, sfix): @@ -3725,7 +4022,7 @@ def __truediv__(self, other): """ Clear fixed-point division. :param other: cfix/cint/regint/int """ - other = parse_type(other, self.k, self.f) + other = self.parse_type(other) if isinstance(other, cfix): return cfix._new(library.cint_cint_division( self.v, other.v, self.k, self.f), k=self.k, f=self.f) @@ -3744,18 +4041,18 @@ def __rtruediv__(self, other): """ Fixed-point division. :param other: sfix/sint/cfix/cint/regint/int """ - other = parse_type(other, self.k, self.f) + other = self.parse_type(other) return other / self + @vectorize def print_plain(self): """ Clear fixed-point output. """ print_float_plain(cint.conv(self.v), cint(-self.f), \ cint(0), cint(0), cint(0)) def output_if(self, cond): - cond_print_plain(cint.conv(cond), self.v, cint(-self.f)) + cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size)) - @vectorize def binary_output(self, player=None): """ Write double-precision floating-point number to ``Player-Data/Binary-Output-P-``. @@ -3764,7 +4061,11 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') + set_global_vector_size(self.size) floatoutput(player, self.v, cint(-self.f), cint(0), cint(0)) + reset_global_vector_size() class _single(_number, _secret_structure): """ Representation as single integer preserving the order """ @@ -3784,6 +4085,8 @@ def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) + :returns: list of length ``n`` + """ sint_inputs = cls.int_type.receive_from_client(n, client_id, message_type) @@ -3821,6 +4124,8 @@ def load_mem(cls, address, mem_type=None): def conv(cls, other): if isinstance(other, cls): return other + elif isinstance(other, (list, tuple)): + return type(other)(cls.conv(x) for x in other) else: try: return cls.from_sint(other) @@ -3895,13 +4200,15 @@ def read_from_file(cls, *args, **kwargs): return stop, [cls._new(x) for x in shares] @classmethod - def write_to_file(cls, shares): + def write_to_file(cls, shares, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. - :param: shares (list or iterable of sfix) + :param shares: (list or iterable of sfix) + :param position: start position (int/regint/cint), + defaults to end of file """ - cls.int_type.write_to_file([x.v for x in shares]) + cls.int_type.write_to_file([x.v for x in shares], position) def store_in_mem(self, address): """ Store in memory by public address. """ @@ -3996,6 +4303,7 @@ def get_vector(self): class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k'] + is_clear = False def set_precision(cls, f, k = None): cls.f = f @@ -4025,17 +4333,19 @@ def set_precision_from_args(cls, program, adapt_ring=False): elif k is not None: raise CompilerError('need to set fractional precision') if 'nearest' in program.args: - print('Nearest rounding instead of proabilistic ' + print('Nearest rounding instead of probabilistic ' 'for fixed-point computation') cls.round_nearest = True - if adapt_ring and program.options.ring: + if adapt_ring and program.options.ring \ + and 'fix_ring' not in program.args \ + and 2 * cls.k > int(program.options.ring): need = 2 ** int(math.ceil(math.log(2 * cls.k, 2))) if need != int(program.options.ring): print('Changing computation modulus to 2^%d' % need) program.set_ring_size(need) @classmethod - def coerce(cls, other): + def coerce(cls, other, equal_precision=None): if isinstance(other, (_fix, cls.clear_type)): return other else: @@ -4055,7 +4365,7 @@ def conv(cls, other): if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f): return other else: - return cls(other) + return super(_fix, cls).conv(other) @classmethod def _new(cls, other, k=None, f=None): @@ -4094,6 +4404,12 @@ def __init__(self, _v=None, k=None, f=None, size=None): elif isinstance(_v, (MemValue, MemFix)): #this is a memvalue object self.v = type(self)(_v.read()).v + elif isinstance(_v, (list, tuple)): + self.v = self.int_type(list(self.conv(x).v for x in _v)) + elif isinstance(_v, personal): + assert _v._v.f == f + assert _v._v.k == k + self.v = self.int_type(personal(_v.player, _v._v.v)) else: raise CompilerError('cannot convert %s to sfix' % _v) if not isinstance(self.v, self.int_type): @@ -4138,7 +4454,7 @@ def mul(self, other): k = len(bin(abs(v))) - 1 other = self.multipliable(v, k, f, self.size) try: - other = self.coerce(other) + other = self.coerce(other, equal_precision=False) except: return NotImplemented if isinstance(other, (_fix, self.clear_type)): @@ -4213,10 +4529,27 @@ class revealed_fix(self.clear_type): k = self.k return revealed_fix._new(val) + def bit_decompose(self, n_bits=None): + """ Bit decomposition. """ + return self.v.bit_decompose(n_bits or self.k) + + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + assert self.f == other.f + self.v.update(other.v) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` for security considerations of the underlying integer operations. + The secret integer is stored as the :py:obj:`v` member. It supports basic arithmetic (``+, -, *, /``), returning :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), @@ -4231,26 +4564,32 @@ class sfix(_fix): :params _v: int/float/regint/cint/sint/sfloat """ int_type = sint + bit_type = sintbit clear_type = cfix + get_type = staticmethod(lambda n: sint) + default_type = sint @vectorized_classmethod - def get_input_from(cls, player): + def get_input_from(cls, player, binary=False): """ Secret fixed-point input. :param player: public (regint/cint/int) :param size: vector size (int, default 1) """ cls.int_type.require_bit_length(cls.k) - v = cls.int_type() - inputmixed('fix', v, cls.f, player) - return cls._new(v) + if binary: + return cls(personal.read_fix(player, cls.f, cls.k, int(binary))) + else: + v = cls.int_type() + inputmixed('fix', v, cls.f, player) + return cls._new(v) @vectorized_classmethod def get_raw_input_from(cls, player): return cls._new(cls.int_type.get_raw_input_from(player)) @vectorized_classmethod - def get_random(cls, lower, upper): + def get_random(cls, lower, upper, symmetric=True): """ Uniform secret random number around centre of bounds. Actual range can be smaller but never larger. @@ -4258,11 +4597,32 @@ def get_random(cls, lower, upper): :param upper: float :param size: vector size (int, default 1) """ + f = cls.f + k = cls.k log_range = int(math.log(upper - lower, 2)) n_bits = log_range + cls.f + gen_range = (2 ** (n_bits) - 1) / 2 ** cls.f + diff = upper - lower + factor = diff / gen_range + real = lambda x: cfix.int_rep(x, f, k) * 2 ** -f + real_range = real(real(factor) * gen_range) average = lower + 0.5 * (upper - lower) - lower = average - 0.5 * 2 ** log_range - return cls._new(cls.int_type.get_random_int(n_bits)) + lower + lower = average - 0.5 * real_range + upper = average + 0.5 * real_range + r = cls._new(cls.int_type.get_random_int(n_bits)) * factor + lower + if symmetric: + lowest = math.floor(lower * 2 ** cls.f) / 2 ** cls.f + highest = math.ceil(upper * 2 ** cls.f) / 2 ** cls.f + if program.verbose: + print('randomness range [%f,%f], ' + 'fringes half the probability' % \ + (lowest, highest)) + return cls.int_type.get_random_bit().if_else(r, -r + 2 * average) + else: + if program.verbose: + print('randomness range [%f,%f], %d bits' % \ + (real(lower), real(lower) + real_range, n_bits)) + return r @classmethod def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None): @@ -4291,8 +4651,21 @@ def dot_product(cls, x, y, res_params=None): def expand_to_vector(self, size): return self._new(self.v.expand_to_vector(size), k=self.k, f=self.f) - def coerce(self, other): - return parse_type(other, k=self.k, f=self.f) + @read_mem_value + def coerce(self, other, equal_precision=True): + res = parse_type(other, k=self.k, f=self.f) + if equal_precision: + # check parameters if available + try: + assert res.k == self.k + assert res.f == self.f + except AttributeError: + pass + return res + + def hard_conv_me(self, cls): + assert cls == sint + return self.v def mul_no_reduce(self, other, res_params=None): assert self.f == other.f @@ -4311,16 +4684,24 @@ def multipliable(v, k, f, size): def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Raw representation possibly written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. :param player: public integer (int/regint/cint) - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) + def secure_shuffle(self, *args, **kwargs): + return self._new(self.v.secure_shuffle(*args, **kwargs), + k=self.k, f=self.f) + + def secure_permute(self, *args, **kwargs): + return self._new(self.v.secure_permute(*args, **kwargs), + k=self.k, f=self.f) + + def prefix_sum(self): + return self._new(self.v.prefix_sum(), k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -4459,7 +4840,7 @@ def for_mux(self, other): def __neg__(self): return self._new(-self.v + 2 * util.expand(self.Z, self.v.size)) -class _unreduced_squant(object): +class _unreduced_squant(Tape._no_truth): def __init__(self, v, params, res_params=None, n_summands=1): self.v = v self.params = params @@ -4577,6 +4958,8 @@ class sfloat(_number, _secret_structure): returning :py:class:`sint`. The other operand can be any of sint/cfix/regint/cint/int/float. + This data type only works with arithmetic computation. + :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) """ __slots__ = ['v', 'p', 'z', 's', 'size'] @@ -4675,6 +5058,9 @@ def get_input_from(cls, player): @vectorize_init @read_mem_value def __init__(self, v, p=None, z=None, s=None, size=None): + if program.options.binary: + raise CompilerError( + 'floating-point operations not supported with binary circuits') self.size = get_global_vector_size() if p is None: if isinstance(v, sfloat): @@ -4696,29 +5082,20 @@ def __init__(self, v, p=None, z=None, s=None, size=None): if isinstance(v, int): if not ((v >= 2**(self.vlen-1) and v < 2**(self.vlen)) or v == 0): raise CompilerError('Floating point number malformed: significand') - self.v = sint(v) - else: - self.v = v if isinstance(p, int): if not (p >= -2**(self.plen - 1) and p < 2**(self.plen - 1)): raise CompilerError('Floating point number malformed: exponent %d not unsigned %d-bit integer' % (p, self.plen)) - self.p = sint(p) - else: - self.p = p if isinstance(z, int): if not (z == 0 or z == 1): raise CompilerError('Floating point number malformed: zero bit') - self.z = sint() - ldsi(self.z, z) - else: - self.z = z if isinstance(s, int): if not (s == 0 or s == 1): raise CompilerError('Floating point number malformed: sign') - self.s = sint() - ldsi(self.s, s) - else: - self.s = s + # copying necessary for update to work properly + self.v = sint(v) + self.p = sint(p) + self.z = sint(z) + self.s = sint(s) def __getitem__(self, index): return sfloat(*(x[index] for x in self)) @@ -4929,10 +5306,12 @@ def __ge__(self, other): """ Secret floating-point comparison. """ return 1 - (self < other) + @vectorize def __gt__(self, other): """ Secret floating-point comparison. """ return self.conv(other) < self + @vectorize def __le__(self, other): """ Secret floating-point comparison. """ return self.conv(other) >= self @@ -4981,10 +5360,24 @@ def reveal(self): :return: cfloat """ return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal()) -class cfloat(object): + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + self.v.update(other.v) + self.p.update(other.p) + self.z.update(other.z) + self.s.update(other.s) + +class cfloat(Tape._no_truth): """ Helper class for printing revealed sfloats. """ __slots__ = ['v', 'p', 'z', 's', 'nan'] + @vectorize_init def __init__(self, v, p=None, z=None, s=None, nan=0): """ Parameters as with :py:class:`sfloat` but public. """ if s is None: @@ -4993,6 +5386,11 @@ def __init__(self, v, p=None, z=None, s=None, nan=0): parts = [cint.conv(x) for x in (v, p, z, s, nan)] self.v, self.p, self.z, self.s, self.nan = parts + @property + def size(self): + return self.v.size + + @vectorize def print_float_plain(self): """ Output. """ print_float_plain(self.v, self.p, self.z, self.s, self.nan) @@ -5032,13 +5430,20 @@ def reveal_to_clients(self, clients): """ self.value_type.reveal_to_clients(clients, [self.get_vector()]) + @staticmethod + def _cmp_fail(*args): + raise CompilerError('equality of data structures is not implemented') + + __eq__ = __ne__ = __le__ = __lt__ = __gt__ = __ge__ = _cmp_fail + class Array(_vectorizable): """ Array accessible by public index. That is, ``a[i]`` works for an array ``a`` and ``i`` being a :py:class:`regint`, :py:class:`cint`, or a Python integer. - :param length: compile-time integer (int) or :py:obj:`None` for unknown length + :param length: compile-time integer (int) or :py:obj:`None` + for unknown length (need to specify :py:obj:`address`) :param value_type: basic type :param address: if given (regint/int), the array will not be allocated @@ -5054,13 +5459,23 @@ class Array(_vectorizable): a[:] += b[:] """ + check_indices = True + @classmethod def create_from(cls, l): """ Convert Python iterator or vector to array. Basic type will be taken from first element, further elements must to be convertible to - that. """ + that. + + :param l: Python iterable or register vector + :returns: :py:class:`Array` of appropriate type containing the contents + of :py:obj:`l` + + """ if isinstance(l, cls): - return l + res = l.same_shape() + res[:] = l[:] + return res if isinstance(l, _number): tmp = l t = type(l) @@ -5081,21 +5496,36 @@ def __init__(self, length, value_type, address=None, debug=None, alloc=True): self.debug = debug self.creator_tape = program.curr_tape self.sink = None - self.check_indices = True if alloc: self.alloc() def alloc(self): - if self.address is None: - self.address = self.value_type.malloc(self.length, - self.creator_tape) + if self._address is None: + try: + self.address = self.value_type.malloc(self.length, + self.creator_tape) + except AttributeError: + raise CompilerError('cannot create Array of %s' % \ + self.value_type) def delete(self): self.value_type.free(self.address) self.address = None - def get_address(self, index): - key = str(index) + @property + def address(self): + if self._address is None: + raise CompilerError('trying access unallocated memory') + return self._address + + @address.setter + def address(self, address): + self._address = address + + def get_address(self, index, size=None): + if isinstance(index, (_secret, _single)): + raise CompilerError('need cleartext index') + key = str(index), size or 1 if self.length is not None: from .GC.types import cbits if isinstance(index, int): @@ -5114,6 +5544,9 @@ def get_address(self, index): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() + if size is not None and isinstance(base, _register) \ + and not issubclass(self.value_type, _vec): + base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ for i in range(n)]) @@ -5128,6 +5561,7 @@ def get_slice(self, index): if index.step == 0: raise CompilerError('slice step cannot be zero') return index.start or 0, \ + index.stop if self.length is None else \ min(index.stop or self.length, self.length), index.step or 1 def __getitem__(self, index): @@ -5160,6 +5594,9 @@ def __setitem__(self, index, value): return self.assign(value, addresses) self._store(value, self.get_address(index)) + def to_array(self): + return self + def get_sub(self, start, stop=None): if stop is None: stop = start @@ -5173,7 +5610,7 @@ def maybe_get(self, condition, index): :param condition: 0/1 (regint/cint/int) :param index: regint/cint/int """ - return condition * self[condition * index] + return self[condition * index].zero_if_not(condition) def maybe_set(self, condition, index, value): """ Change entry if condition is true. @@ -5203,13 +5640,20 @@ def _load(self, address): return self.value_type.load_mem(address) def _store(self, value, address): - self.value_type.conv(value).store_in_mem(address) + tmp = self.value_type.conv(value) + if not isinstance(tmp, _vec) and tmp.size != self.value_type.mem_size(): + raise CompilerError('size mismatch in array assignment') + tmp.store_in_mem(address) def __len__(self): return self.length def total_size(self): - return len(self) * self.value_type.n_elements() + return self.length * self.value_type.n_elements() + + @property + def shape(self): + return [self.length] def __iter__(self): for i in range(self.length): @@ -5231,7 +5675,8 @@ def assign(self, other, base=0): except: pass try: - self.value_type.conv(other).store_in_mem(self.get_address(base)) + other = self.value_type.conv(other) + other.store_in_mem(self.get_address(base, other.size)) if len(self) != None and util.is_constant(base): assert len(self) >= other.size + base except (AttributeError, CompilerError): @@ -5251,16 +5696,26 @@ def assign_all(self, value, use_threads=True, conv=True): """ Assign the same value to all entries. :param value: convertible to basic type """ - if conv: - value = self.value_type.conv(value) - if value.size != 1: - raise CompilerError('cannot assign vector to all elements') - mem_value = MemValue(value) + from Compiler.GC.types import bits + use_vector = util.is_constant(value) and \ + not issubclass(self.value_type, (bits, squant)) + if not use_vector: + if conv: + value = self.value_type.conv(value) + if value.size != 1: + raise CompilerError('cannot assign vector to all elements') + mem_value = MemValue(value) self.address = MemValue.if_necessary(self.address) - n_threads = 8 if use_threads and len(self) > 2**20 else None - @library.for_range_multithread(n_threads, 1024, len(self)) - def f(i): - self[i] = mem_value + n_threads = 8 if use_threads and util.is_constant(self.length) and \ + len(self) > 2**20 else None + @library.multithread(n_threads, self.length) + def _(base, size): + if use_vector: + self.assign_vector(self.value_type(value, size=size), base) + else: + @library.for_range_opt(size) + def _(i): + self[base + i] = mem_value return self def get_vector(self, base=0, size=None): @@ -5269,10 +5724,16 @@ def get_vector(self, base=0, size=None): :param base: starting point (regint/cint/int) :param size: length (compile-time int) """ size = size or self.length - base - return self.value_type.load_mem(self.get_address(base), size=size) + return self.value_type.load_mem(self.get_address(base, size), size=size) get_part_vector = get_vector + def get_reverse_vector(self): + """ Return vector with content in reverse order. """ + size = self.length + address = regint.inc(size, size - 1, -1) + return self.value_type.load_mem(self.address + address, size=size) + def get_part(self, base, size): """ Part array. @@ -5291,13 +5752,21 @@ def get(self, indices): regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) - def get_slice_vector(self, slice): + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 assert len(slice) <= self.total_size() base = regint.inc(len(slice), slice.address, 1, 1) - inc = regint.inc(len(slice), 0, 1, 1, 1) + inc = regint.inc(len(slice), self.address, 1, 1, 1) addresses = slice.value_type.load_mem(base) + inc - return self.value_type.load_mem(self.address + addresses) + return addresses + + def get_slice_vector(self, slice): + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(addresses) def expand_to_vector(self, index, size): """ Create vector from single entry. @@ -5313,7 +5782,15 @@ def expand_to_vector(self, index, size): def get_mem_value(self, index): return MemValue(self[index], self.get_address(index)) - def input_from(self, player, budget=None, raw=False): + def concat(self, other): + """ Concatenate two arrays. """ + assert self.value_type == other.value_type + res = Array(len(self) + len(other), self.value_type) + res.assign_vector(self[:]) + res.assign_vector(other[:], len(self)) + return res + + def input_from(self, player, budget=None, raw=False, **kwargs): """ Fill with inputs from player if supported by type. :param player: public (regint/cint/int) """ @@ -5322,11 +5799,15 @@ def input_from(self, player, budget=None, raw=False): else: input_from = self.value_type.get_input_from try: - self.assign(input_from(player, size=len(self))) - except TypeError: - @library.for_range_opt(len(self), budget=budget) + @library.multithread(None, len(self), + max_size=budget or program.budget) + def _(base, size): + self.assign(input_from(player, size=size, **kwargs), base) + except (TypeError, CompilerError): + print (budget) + @library.for_range_opt(self.length, budget=budget) def _(i): - self[i] = input_from(player) + self[i] = input_from(player, **kwargs) def read_from_file(self, start): """ Read content from ``Persistence/Transactions-P.data``. @@ -5341,11 +5822,14 @@ def read_from_file(self, start): self.assign(shares) return stop - def write_to_file(self): + def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. + + :param position: start position (int/regint/cint), + defaults to end of file """ - self.value_type.write_to_file(list(self)) + self.value_type.write_to_file(list(self), position) def __add__(self, other): """ Vector addition. @@ -5353,14 +5837,12 @@ def __add__(self, other): :param other: vector or container of same length and type that supports operations with type of this array """ if is_zero(other): return self - assert len(self) == len(other) return self.get_vector() + other def __sub__(self, other): """ Vector subtraction. :param other: vector or container of same length and type that supports operations with type of this array """ - assert len(self) == len(other) return self.get_vector() - other def __mul__(self, value): @@ -5384,26 +5866,62 @@ def __pow__(self, value): __radd__ = __add__ __rmul__ = __mul__ + def __iadd__(self, other): + self[:] += other.get_vector() + return self + + def __isub__(self, other): + self[:] -= other.get_vector() + return self + + def __imul__(self, other): + self[:] *= other.get_vector() + return self + + def __itruediv__(self, other): + self[:] /= other.get_vector() + return self + def __neg__(self): return -self.get_vector() def shuffle(self): """ Insecure shuffle in place. """ - if self.value_type == regint: - self.assign(self.get_vector().shuffle()) - else: - @library.for_range(len(self)) - def _(i): - j = regint.get_random(64) % (len(self) - i) - tmp = self[i] - self[i] = self[i + j] - self[i + j] = tmp + self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + + def secure_shuffle(self): + """ Secure shuffle in place according to the security model. + See :py:func:`MultiArray.secure_shuffle` for references. """ + self.assign_vector(self.get_vector().secure_shuffle()) + + def secure_permute(self, *args, **kwargs): + """ Secure permutate in place according to the security model. + See :py:func:`MultiArray.secure_shuffle` for references. + + :param permutation: output of :py:func:`sint.get_secure_shuffle()` + :param reverse: whether to apply inverse (default: False) + + """ + self.assign_vector(self.get_vector().secure_permute(*args, **kwargs)) + + def randomize(self, *args): + """ Randomize array according to data type. + If it is :py:class:`sfix`, the following will sample an + individual uniformly random entry of the array + :py:obj:`M` roughly in the range :math:`[a,b]`:: + + M.randomize(a, b) + + """ + self.assign_vector(self.value_type.get_random(*args, size=len(self))) def reveal(self): """ Reveal the whole array. :returns: Array of relevant clear type. """ - return Array.create_from(x.reveal() for x in self) + res = Array.create_from(self.get_vector().reveal()) + library.break_point() + return res def reveal_list(self): """ Reveal as list. """ @@ -5411,6 +5929,21 @@ def reveal_list(self): reveal_nested = reveal_list + def print_reveal_nested(self, end='\n'): + """ Reveal and print as list. + + :param end: string to print after (default: line break) + """ + if util.is_constant(self.length): + library.print_str('%s' + end, self.get_vector().reveal()) + else: + library.print_str('[') + @library.for_range(self.length - 1) + def _(i): + library.print_str('%s, ', self[i].reveal()) + library.print_str('%s', self[self.length - 1].reveal()) + library.print_str(']' + end) + def reveal_to_binary_output(self, player=None): """ Reveal to binary output if supported by type. @@ -5428,15 +5961,44 @@ def binary_output(self, player=None): """ self.get_vector().binary_output(player) - def sort(self, n_threads=None): + def reveal_to(self, player): + """ Reveal secret array to :py:obj:`player`. + + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` containing an array + """ + return personal(player, self.create_from(self[:].reveal_to(player)._v)) + + def sort(self, n_threads=None, batcher=False, n_bits=None): """ - Sort in place using Batchers' odd-even merge mergesort - with complexity :math:`O(n (\log n)^2)`. + Sort in place using `radix sort + `_ with complexity + :math:`O(n \log n)` for :py:class:`sint` and :py:class:`sfix`, + and `Batcher's odd-even mergesort + `_ with :math:`O(n (\log + n)^2)` for :py:class:`sfloat`. :param n_threads: number of threads to use (single thread by - default) + default), need to use Batcher's algorithm for several threads + :param batcher: use Batcher's odd-even mergesort in any case + :param n_bits: number of bits in keys (default: global bit length) """ - library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + if batcher or self.value_type.n_elements() > 1 or \ + program.options.binary: + library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + else: + if n_threads or 1 > 1: + raise CompilerError('multi-threaded sorting only implemented ' + 'with Batcher\'s odd-even mergesort') + from . import sorting + sorting.radix_sort(self, self, n_bits=n_bits) + + def Array(self, size): + # compatibility with registers + return Array(size, self.value_type) + + def output_if(self, cond): + library.print_str_if(cond, '%s', self.get_vector()) def __str__(self): return '%s array of length %s at %s' % (self.value_type, len(self), @@ -5449,6 +6011,8 @@ def __str__(self): class SubMultiArray(_vectorizable): """ Multidimensional array functionality. Don't construct this directly, use :py:class:`MultiArray` instead. """ + check_indices = True + def __init__(self, sizes, value_type, address, index, debug=None): self.sizes = tuple(sizes) self.value_type = _get_type(value_type) @@ -5458,7 +6022,6 @@ def __init__(self, sizes, value_type, address, index, debug=None): self.address = None self.sub_cache = {} self.debug = debug - self.check_indices = True if debug: library.print_ln_if(self.address + reduce(operator.mul, self.sizes) * self.value_type.n_elements() > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug) @@ -5467,10 +6030,10 @@ def __getitem__(self, index): :param index: public (regint/cint/int) :return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise""" - if util.is_constant(index) and index >= self.sizes[0]: - raise StopIteration if isinstance(index, slice) and index == slice(None): return self.get_vector() + if isinstance(index, int) and index < 0: + index += self.sizes[0] key = program.curr_block, str(index) if key not in self.sub_cache: if util.is_constant(index) and \ @@ -5484,13 +6047,16 @@ def __getitem__(self, index): self.sub_cache[key] = \ Array(self.sizes[1], self.value_type, \ self.address + index * self.sizes[1] * - self.value_type.n_elements(), \ + self.value_type.n_elements() * \ + self.value_type.mem_size(), \ debug=self.debug) else: self.sub_cache[key] = \ SubMultiArray(self.sizes[1:], self.value_type, \ self.address, index, debug=self.debug) - return self.sub_cache[key] + res = self.sub_cache[key] + res.check_indices = self.check_indices + return res def __setitem__(self, index, other): """ Part assignment. @@ -5505,18 +6071,44 @@ def __len__(self): """ Size of top dimension. """ return self.sizes[0] + @property + def shape(self): + return list(self.sizes) + + def __iter__(self): + return (self[i] for i in range(len(self))) + + def to_array(self): + assert self.value_type.n_elements() == 1 and \ + self.value_type.mem_size() == 1 + return Array(self.total_size(), self.value_type, address=self.address) + + def maybe_get(self, condition, index): + return self[condition * index] + + def maybe_set(self, condition, index, value): + for i, x in enumerate(value): + self.maybe_get(condition, index).maybe_set(condition, i, x) + def assign_all(self, value): """ Assign the same value to all entries. :param value: convertible to relevant basic type """ - @library.for_range(self.sizes[0]) - def f(i): - self[i].assign_all(value) + try: + self.to_array().assign_all(value) + except AssertionError: + @library.for_range(self.sizes[0]) + def f(i): + self[i].assign_all(value) return self def total_size(self): return reduce(operator.mul, self.sizes) * self.value_type.n_elements() + def part_size(self): + return reduce(operator.mul, self.sizes[1:]) * \ + self.value_type.n_elements() + def get_vector(self, base=0, size=None): """ Return vector with content. Not implemented for floating-point. @@ -5575,13 +6167,21 @@ def get_slice_vector(self, slice): :param slice: regint array """ + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(self.address + addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(self.address + addresses) + + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) assert len(slice) * part_size <= self.total_size() base = regint.inc(len(slice) * part_size, slice.address, 1, part_size) inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size) addresses = slice.value_type.load_mem(base) * part_size + inc - return self.value_type.load_mem(self.address + addresses) + return addresses def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 @@ -5607,7 +6207,7 @@ def get_addresses(self, *indices): def get_vector_by_indices(self, *indices): """ Vector with potential asterisks. The potential retrieves - all entry where the first dimension index is 0, and the third + all entries where the first dimension index is 0, and the third dimension index is 1:: a.get_vector_by_indices(0, None, 1) @@ -5638,31 +6238,43 @@ def get_part(self, start, size): return MultiArray([size] + list(self.sizes[1:]), self.value_type, address=self[start].address) - def input_from(self, player, budget=None, raw=False): + def concat(self, other): + """ Concatenate two multi-arrays of matching dimension. """ + assert self.sizes[1:] == other.sizes[1:] + assert self.value_type == other.value_type + res = MultiArray((self.sizes[0] + other.sizes[0],) + self.sizes[1:], + self.value_type) + res.assign_vector(self[:]) + res.assign_part_vector(other[:], self.sizes[0]) + return res + + def input_from(self, player, budget=None, raw=False, **kwargs): """ Fill with inputs from player if supported by type. :param player: public (regint/cint/int) """ - budget = budget or Tape.Register.maximum_size - if (self.total_size() < budget) and \ - self.value_type.n_elements() == 1: - if raw or program.always_raw(): - input_from = self.value_type.get_raw_input_from - else: - input_from = self.value_type.get_input_from - self.assign_vector(input_from(player, size=self.total_size())) + if util.is_constant(self.total_size()) and \ + self.value_type.n_elements() == 1 and \ + self.value_type.mem_size() == 1: + self.to_array().input_from(player, budget=budget, raw=raw, **kwargs) else: - @library.for_range_opt(self.sizes[0], - budget=budget / self[0].total_size()) + @library.for_range_opt(self.sizes[0], budget=budget) def _(i): - self[i].input_from(player, budget=budget, raw=raw) + self[i].input_from(player, budget=budget, raw=raw, **kwargs) - def write_to_file(self): + def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. + + :param position: start position (int/regint/cint), + defaults to end of file """ @library.for_range(len(self)) def _(i): - self[i].write_to_file() + if position is None: + my_pos = None + else: + my_pos = position + i * self[i].total_size() + self[i].write_to_file(my_pos) def read_from_file(self, start): """ Read content from ``Persistence/Transactions-P.data``. @@ -5731,6 +6343,22 @@ def iadd(self, other): assert self.sizes == other.sizes self.assign_vector(self.get_vector() + other.get_vector()) + def __iadd__(self, other): + self[:] += other.get_vector() + return self + + def __isub__(self, other): + self[:] -= other.get_vector() + return self + + def __imul__(self, other): + self[:] *= other.get_vector() + return self + + def __itruediv__(self, other): + self[:] /= other.get_vector() + return self + def __mul__(self, other): # legacy function return self.mul(other) @@ -5739,11 +6367,15 @@ def mul(self, other, res_params=None): # legacy function return self.dot(other, res_params) - def dot(self, other, res_params=None): + def dot(self, other, res_params=None, n_threads=None): """ Matrix-matrix and matrix-vector multiplication. :param self: two-dimensional - :param other: Matrix or Array of matching size and type """ + :param other: Matrix or Array of matching size and type + :param n_threads: number of threads (default: all in same thread) + :rtype: Matrix or Array of appropriate size and type + + """ assert len(self.sizes) == 2 if isinstance(other, Array): assert len(other) == self.sizes[1] @@ -5767,14 +6399,25 @@ class t(self.value_type): pass t.params = res_params else: - t = self.value_type + if issubclass(self.value_type, _secret_structure): + t = self.value_type + else: + t = type(self.value_type(0) * other.value_type(0)) res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: - res_matrix.assign_vector(self.direct_mul(other)) + self.value_type.direct_matrix_mul + assert self.value_type == other.value_type + max_size = _register.maximum_size // res_matrix.sizes[1] + @library.multithread(n_threads, self.sizes[0], max_size) + def _(base, size): + res_matrix.assign_part_vector( + self.get_part(base, size).direct_mul(other), base) except AttributeError: + assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() + self.value_type.matrix_mul A = self.get_vector() B = other.get_vector() res_matrix.assign_vector( @@ -5782,17 +6425,17 @@ class t(self.value_type): res_params)) except (AttributeError, AssertionError): # fallback for sfloat etc. - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt_multithread(n_threads, self.sizes[0]) def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( self[i], other, res_params) - except AttributeError: + except (AttributeError, CompilerError): # fallback for binary circuits - @library.for_range(other.sizes[1]) + @library.for_range_opt(other.sizes[1]) def _(j): res_matrix[i][j] = 0 - @library.for_range(self.sizes[1]) + @library.for_range_opt(self.sizes[1]) def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix @@ -5803,6 +6446,8 @@ def _(k): def direct_mul(self, other, reduce=True, indices=None): """ Matrix multiplication in the virtual machine. + Unlike :py:func:`dot`, this only works for sint and sfix, and it + returns a vector instead of a data structure. :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` @@ -5827,6 +6472,7 @@ def direct_mul(self, other, reduce=True, indices=None): other_sizes = other.sizes assert len(other.sizes) == 2 assert self.sizes[1] == other_sizes[0] + assert self.value_type == other.value_type return self.value_type.direct_matrix_mul(self.address, other.address, self.sizes[0], *other_sizes, reduce=reduce, indices=indices) @@ -5844,6 +6490,7 @@ def direct_mul_trans(self, other, reduce=True, indices=None): """ assert len(self.sizes) == 2 assert len(other.sizes) == 2 + assert other.address != None if indices is None: assert self.sizes[1] == other.sizes[1] indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] @@ -5877,74 +6524,78 @@ def direct_trans_mul(self, other, reduce=True, indices=None): self.address, other.address, None, 1, other.sizes[1], reduce=reduce, indices=indices) - def direct_mul_to_matrix(self, other): - """ Matrix multiplication in the virtual machine. + def trans_mul_to(self, other, res, n_threads=None): + """ + Matrix multiplication with the transpose of :py:obj:`self` + in the virtual machine. :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :returns: :py:obj:`Matrix` + :param res: matrix of matching dimension to store result + :param n_threads: number of threads (default: single thread) + """ + assert other.sizes[0] == self.sizes[0] + assert res.sizes[0] == self.sizes[1] + assert res.sizes[1] == other.sizes[1] + assert len(res.sizes) == 2 + @library.for_range_multithread(n_threads, 1, self.sizes[1]) + def _(i): + indices = [regint(i), regint.inc(self.sizes[0])] + indices += [regint.inc(i) for i in other.sizes] + res[i] = self.direct_trans_mul(other, indices=indices) + + def mul_trans_to(self, other, res, n_threads=None): + """ + Matrix multiplication with the transpose of :py:obj:`other` + in the virtual machine. + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param res: matrix of matching dimension to store result + :param n_threads: number of threads (default: single thread) """ + assert other.sizes[1] == self.sizes[1] + assert res.sizes[0] == self.sizes[0] + assert res.sizes[1] == other.sizes[0] + assert len(res.sizes) == 2 + @library.for_range_multithread(n_threads, 1, self.sizes[0]) + def _(i): + indices = [regint(i), regint.inc(self.sizes[1])] + indices += [regint.inc(i) for i in reversed(other.sizes)] + res[i] = self.direct_mul_trans(other, indices=indices) + + def direct_mul_to_matrix(self, other): + # Obsolete. Use dot(). res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) res.assign_vector(self.direct_mul(other)) return res - def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True, - res=None): - assert len(self.sizes) == 2 - assert len(other.sizes) == 2 - if res is None: - if reduce: - res_matrix = Matrix(n_rows, n_columns, self.value_type) - else: - res_matrix = Matrix(n_rows, n_columns, \ - self.value_type.unreduced_type) - else: - res_matrix = res - @library.for_range_opt(n_rows) - def _(i): - @library.for_range_opt(n_columns) - def _(j): - col = column(other, j) - r = row(self, i) - if reduce: - res_matrix[i][j] = self.value_type.dot_product(r, col) - else: - entry = self.value_type.unreduced_dot_product(r, col) - res_matrix[i][j] = entry - return res_matrix - def plain_mul(self, other, res=None): - """ Alternative matrix multiplication. - - :param self: two-dimensional - :param other: two-dimensional container of matching type and size """ - assert other.sizes[0] == self.sizes[1] - return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \ - other.sizes[1], \ - lambda x, j: [x[k][j] for k in range(len(x))], - res=res) + raise CompilerError('Deprecated functionality. Use dot()') def mul_trans(self, other): """ Matrix multiplication with transpose of :py:obj:`other`. :param self: two-dimensional - :param other: two-dimensional container of matching type and size """ - assert other.sizes[1] == self.sizes[1] - return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \ - other.sizes[0], lambda x, j: x[j]) + :param other: two-dimensional container of matching type and size + :return: Matrix of matching type and size + + """ + res = Matrix(self.sizes[0], other.sizes[0], self.value_type) + self.mul_trans_to(other, res) + return res - def trans_mul(self, other, reduce=True, res=None): + def trans_mul(self, other): """ Matrix multiplication with transpose of :py:obj:`self` :param self: two-dimensional - :param other: two-dimensional container of matching type and size """ - assert other.sizes[0] == self.sizes[0] - return self.budget_mul(other, self.sizes[1], \ - lambda x, j: [x[k][j] for k in range(len(x))], \ - other.sizes[1], \ - lambda x, j: [x[k][j] for k in range(len(x))], - reduce=reduce, res=res) + :param other: two-dimensional container of matching type and size + :return: Matrix of matching type and size + + """ + res = Matrix(self.sizes[1], other.sizes[1], self.value_type) + self.trans_mul_to(other, res) + return res def parallel_mul(self, other): assert self.sizes[1] == other.sizes[0] @@ -5974,13 +6625,15 @@ def transpose(self): res = Matrix(self.sizes[1], self.sizes[0], self.value_type) library.break_point() if self.value_type.n_elements() == 1: - @library.for_range_opt(self.sizes[0]) - def _(j): - res.set_column(j, self[j][:]) + nr = self.sizes[1] + nc = self.sizes[0] + a = regint.inc(nr * nc, 0, nr, 1, nc) + b = regint.inc(nr * nc, 0, 1, nc) + res[:] = self.value_type.load_mem(self.address + a + b) else: - @library.for_range_opt(self.sizes[1]) + @library.for_range_opt(self.sizes[1], budget=100) def _(i): - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt(self.sizes[0], budget=100) def _(j): res[i][j] = self[j][i] library.break_point() @@ -5992,6 +6645,75 @@ def trace(self): assert self.sizes[0] == self.sizes[1] return sum(self[i][i] for i in range(self.sizes[0])) + def diag(self): + """ Matrix diagonal. """ + assert len(self.sizes) == 2 + assert self.sizes[0] == self.sizes[1] + n = self.sizes[0] + return self.array.get(regint.inc(n, 0, n + 1)) + + def secure_shuffle(self): + """ Securely shuffle rows (first index). This uses the algorithm in + Section 4.3 of `Keller and Scholl + `_ or Section 3.2 of + `Asharov et al. `_ if applicable. + """ + self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) + + def secure_permute(self, permutation, reverse=False): + """ Securely permute rows (first index). See + :py:func:`secure_shuffle` for references. + + :param permutation: output of :py:func:`sint.get_secure_shuffle()` + :param reverse: whether to apply inverse (default: False) + + """ + self.assign_vector(self.get_vector().secure_permute( + permutation, self.part_size(), reverse)) + + def sort(self, key_indices=None, n_bits=None): + """ Sort sub-arrays (different first index) in place. + This uses `radix sort `_. + + :param key_indices: indices to sorting keys, for example + ``(1, 2)`` to sort three-dimensional array ``a`` by keys + ``a[*][1][2]``. Default is ``(0, ..., 0)`` of correct length. + :param n_bits: number of bits in keys (default: global bit length) + + """ + if program.options.binary: + assert key_indices is None + assert len(self.sizes) == 2 + library.loopy_odd_even_merge_sort(self) + return + if key_indices is None: + key_indices = (0,) * (len(self.sizes) - 1) + key_indices = (None,) + util.tuplify(key_indices) + from . import sorting + keys = self.get_vector_by_indices(*key_indices) + sorting.radix_sort(keys, self, n_bits=n_bits) + + def randomize(self, *args, n_threads=None): + """ Randomize according to data type. + If it is :py:class:`sfix`, the following will sample an + individual uniformly random entry of the multi-array + :py:obj:`M` roughly in the range :math:`[a,b]`:: + + M.randomize(a, b) + + """ + @library.multithread(n_threads, self.total_size(), + max_size=program.budget) + def _(base, size): + self.assign_vector( + self.value_type.get_random(*args, size=size), base=base) + + def reveal(self): + """ Reveal to :py:obj:`MultiArray` of same shape. """ + res = MultiArray(self.sizes, self.value_type.clear_type) + res[:] = self.get_vector().reveal() + return res + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6007,6 +6729,22 @@ def f(sizes): return [f(sizes[1:]) for i in range(sizes[0])] return f(self.sizes) + def print_reveal_nested(self, end='\n'): + """ Reveal and print as nested list. + + :param end: string to print after (default: line break) + """ + if util.is_constant(self.total_size()) and \ + self.total_size() < program.budget: + library.print_str('%s' + end, self.reveal_nested()) + else: + library.print_str('[') + @library.for_range(len(self) - 1) + def _(i): + self[i].print_reveal_nested(end=', ') + self[len(self) - 1].print_reveal_nested(end='') + library.print_str(']' + end) + def reveal_to_binary_output(self, player=None): """ Reveal to binary output if supported by type. @@ -6042,13 +6780,17 @@ class MultiArray(SubMultiArray): a[2][:] = a[0][:] * a[1][:] """ + @staticmethod + def disable_index_checks(): + SubMultiArray.check_indices = False + def __init__(self, sizes, value_type, debug=None, address=None, alloc=True): if isinstance(address, Array): self.array = address else: self.array = Array(reduce(operator.mul, sizes), \ value_type, address=address, alloc=alloc) - SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \ + SubMultiArray.__init__(self, sizes, value_type, self.array._address, 0, debug=debug) if len(sizes) < 2: raise CompilerError('Use Array') @@ -6079,6 +6821,43 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + @staticmethod + def create_from(rows): + rows = list(rows) + if isinstance(rows[0], (list, tuple, Array)): + t = type(rows[0][0]) + else: + t = type(rows[0]) + if t != sfix: + for row in rows: + if isinstance(row, sfix) or \ + (isinstance(row, Array) and row.value_type == sfix): + raise CompilerError( + 'accidental shortening by creating matrix') + res = Matrix(len(rows), len(rows[0]), t) + for i in range(len(rows)): + res[i].assign(rows[i]) + return res + + def get_column(self, index): + """ Get column as vector. + + :param index: regint/cint/int + """ + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.sizes[1]) + return self.value_type.load_mem(addresses) + + def get_columns(self): + return (self.get_column(i) for i in range(self.sizes[1])) + + def get_column_by_row_indices(self, rows, column): + assert self.value_type.n_elements() == 1 + addresses = rows * self.sizes[1] + \ + regint.inc(len(rows), self.address + column, 0) + return self.value_type.load_mem(addresses) + def set_column(self, index, vector): """ Change column. @@ -6090,6 +6869,20 @@ def set_column(self, index, vector): self.sizes[1]) self.value_type.conv(vector).store_in_mem(addresses) + def concat_columns(self, other): + """ Concatenate two matrices by columns. """ + assert self.sizes[0] == other.sizes[0] + assert self.value_type == other.value_type + res = Matrix(self.sizes[0], self.sizes[1] + other.sizes[1], + self.value_type) + @library.for_range(self.sizes[1]) + def _(i): + res.set_column(i, self.get_column(i)) + @library.for_range(other.sizes[1]) + def _(i): + res.set_column(self.sizes[1] + i, other.get_column(i)) + return res + class VectorArray(object): def __init__(self, length, value_type, vector_size, address=None): self.array = Array(length * vector_size, value_type, address) @@ -6215,7 +7008,10 @@ def read(self): :return: relevant basic type instance """ self.check() if program.curr_block != self.last_write_block: - self.register = self.value_type.load_mem(self.address) + from Compiler.GC.types import sbitvec + self.register = self.value_type.load_mem( + self.address, size=self.size \ + if issubclass(self.value_type, (_register, sbitvec)) else None) self.last_write_block = program.curr_block return self.register @@ -6226,7 +7022,11 @@ def write(self, value): self.check() if isinstance(value, MemValue): value = value.read() - value = self.value_type.conv(value) + try: + value = self.value_type.conv(value) + except: + raise CompilerError('Cannot store %s as MemValue of %s' % \ + (type(value), self.value_type)) if value.size != self.size: raise CompilerError('size mismatch') self.register = value @@ -6264,6 +7064,7 @@ def reveal(self): if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) bit_and = lambda self,other: self.read().bit_and(other) + bit_not = lambda self: self.read().bit_not() def expand_to_vector(self, size=None): if program.curr_block == self.last_write_block: diff --git a/Compiler/util.py b/Compiler/util.py index aa491e422..6c3c3ce59 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -116,6 +116,11 @@ def round_to_int(x): return x.round_to_int() def tree_reduce(function, sequence): + try: + return sequence.tree_reduce(function) + except AttributeError: + pass + sequence = list(sequence) assert len(sequence) > 0 n = len(sequence) @@ -233,6 +238,9 @@ def mem_size(x): except AttributeError: return 1 +def find_in_dict(d, v): + return list(d.keys())[list(d.values()).index(v)] + class set_by_id(object): def __init__(self, init=[]): self.content = {} @@ -257,6 +265,9 @@ def add(self, value): def pop(self): return self.content.popitem()[1] + def remove(self, value): + del self.content[id(value)] + def __ior__(self, values): for value in values: self.add(value) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..79dc702bb --- /dev/null +++ b/Dockerfile @@ -0,0 +1,151 @@ +############################################################################### +# Build this stage for a build environment, e.g.: # +# # +# docker build --tag mpspdz:buildenv --target buildenv . # +# # +# The above is equivalent to: # +# # +# docker build --tag mpspdz:buildenv \ # +# --target buildenv \ # +# --build-arg arch=native \ # +# --build-arg cxx=clang++-11 \ # +# --build-arg use_ntl=0 \ # +# --build-arg prep_dir="Player-Data" \ # +# --build-arg ssl_dir="Player-Data" # +# --build-arg cryptoplayers=0 # +# # +# To build for an x86-64 architecture, with g++, NTL (for HE), custom # +# prep_dir & ssl_dir, and to use encrypted channels for 4 players: # +# # +# docker build --tag mpspdz:buildenv \ # +# --target buildenv \ # +# --build-arg arch=x86-64 \ # +# --build-arg cxx=g++ \ # +# --build-arg use_ntl=1 \ # +# --build-arg prep_dir="/opt/prepdata" \ # +# --build-arg ssl_dir="/opt/ssl" # +# --build-arg cryptoplayers=4 . # +# # +# To work in a container to build different machines, and compile programs: # +# # +# docker run --rm -it mpspdz:buildenv bash # +# # +# Once in the container, build a machine and compile a program: # +# # +# $ make replicated-ring-party.x # +# $ ./compile.py -R 64 tutorial # +# # +############################################################################### +FROM python:3.10.3-bullseye as buildenv + +RUN apt-get update && apt-get install -y --no-install-recommends \ + automake \ + build-essential \ + clang-11 \ + cmake \ + git \ + libboost-dev \ + libboost-thread-dev \ + libclang-dev \ + libntl-dev \ + libsodium-dev \ + libssl-dev \ + libtool \ + m4 \ + texinfo \ + yasm \ + vim \ + gdb \ + valgrind \ + && rm -rf /var/lib/apt/lists/* + +# mpir +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/ +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/ +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/ + +ENV MP_SPDZ_HOME /usr/src/MP-SPDZ +WORKDIR $MP_SPDZ_HOME + +RUN pip install --upgrade pip ipython + +COPY . . + +ARG arch=native +ARG cxx=clang++-11 +ARG use_ntl=0 +ARG prep_dir="Player-Data" +ARG ssl_dir="Player-Data" + +RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \ + && echo "CXX = ${cxx}" >> CONFIG.mine \ + && echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \ + && echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \ + && echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \ + >> CONFIG.mine \ + && mkdir -p $prep_dir $ssl_dir \ + && echo "PREP_DIR = '-DPREP_DIR=\"${prep_dir}/\"'" >> CONFIG.mine \ + && echo "SSL_DIR = '-DSSL_DIR=\"${ssl_dir}/\"'" >> CONFIG.mine + +# ssl keys +ARG cryptoplayers=0 +ENV PLAYERS ${cryptoplayers} +RUN ./Scripts/setup-ssl.sh ${cryptoplayers} ${ssl_dir} + +RUN make boost libote + +############################################################################### +# Use this stage to a build a specific virtual machine. For example: # +# # +# docker build --tag mpspdz:shamir \ # +# --target machine \ # +# --build-arg machine=shamir-party.x \ # +# --build-arg gfp_mod_sz=4 . # +# # +# The above will build shamir-party.x with 256 bit length. # +# # +# If no build arguments are passed (via --build-arg), mascot-party.x is built # +# with the default 128 bit length. # +############################################################################### +FROM buildenv as machine + +ARG machine="mascot-party.x" + +ARG gfp_mod_sz=2 + +RUN echo "MOD = -DGFP_MOD_SZ=${gfp_mod_sz}" >> CONFIG.mine + +RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/ + + +################################################################################ +# This is the default stage. Use it to compile a high-level program. # +# By default, tutorial.mpc is compiled with --field=64 bits. # +# # +# docker build --tag mpspdz:mascot-tutorial \ # +# --build-arg src=tutorial \ # +# --build-arg compile_options="--field=64" . # +# # +# Note that build arguments from previous stages can also be passed. For # +# instance, building replicated-ring-party.x, for 3 crypto players with custom # +# PREP_DIR and SSL_DIR, and compiling tutorial.mpc with --ring=64: # +# # +# docker build --tag mpspdz:replicated-ring \ # +# --build-arg machine=replicated-ring-party.x \ # +# --build-arg prep_dir=/opt/prep \ # +# --build-arg ssl_dir=/opt/ssl \ # +# --build-arg cryptoplayers=3 \ # +# --build-arg compile_options="--ring=64" . # +# # +# Test it: # +# # +# docker run --rm -it mpspdz:replicated-ring ./Scripts/ring.sh tutorial # +################################################################################ +FROM machine as program + +ARG src="tutorial" +ARG compile_options="--field=64" +RUN ./compile.py ${compile_options} ${src} +RUN mkdir -p Player-Data \ + && echo 1 2 3 4 > Player-Data/Input-P0-0 \ + && echo 1 2 3 4 > Player-Data/Input-P1-0 diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 23f81b9ef..ecf7011ba 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -22,4 +22,5 @@ int main() generate_mac_keys>(key, 2, prefix); make_mult_triples>(key, 2, 1000, false, prefix); make_inverse>(key, 2, 1000, false, prefix); + P256Element::finish(); } diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 8437f39d2..1ff3273f8 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -14,7 +14,14 @@ void P256Element::init() curve = EC_GROUP_new_by_curve_name(NID_secp256k1); assert(curve != 0); auto modulus = EC_GROUP_get0_order(curve); - Scalar::init_field(BN_bn2dec(modulus), false); + auto mod = BN_bn2dec(modulus); + Scalar::init_field(mod, false); + free(mod); +} + +void P256Element::finish() +{ + EC_GROUP_free(curve); } P256Element::P256Element() @@ -29,7 +36,7 @@ P256Element::P256Element(const Scalar& other) : { BIGNUM* exp = BN_new(); BN_dec2bn(&exp, bigint(other).get_str().c_str()); - assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0); + assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0); BN_free(exp); } @@ -38,10 +45,15 @@ P256Element::P256Element(word other) : { BIGNUM* exp = BN_new(); BN_dec2bn(&exp, to_string(other).c_str()); - assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0); + assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0); BN_free(exp); } +P256Element::~P256Element() +{ + EC_POINT_free(point); +} + P256Element& P256Element::operator =(const P256Element& other) { assert(EC_POINT_copy(point, other.point) != 0); @@ -56,7 +68,11 @@ void P256Element::check() P256Element::Scalar P256Element::x() const { BIGNUM* x = BN_new(); +#if OPENSSL_VERSION_MAJOR >= 3 + assert(EC_POINT_get_affine_coordinates(curve, point, x, 0, 0) != 0); +#else assert(EC_POINT_get_affine_coordinates_GFp(curve, point, x, 0, 0) != 0); +#endif char* xx = BN_bn2dec(x); Scalar res((bigint(xx))); OPENSSL_free(xx); @@ -95,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const return not cmp; } -void P256Element::pack(octetStream& os) const +void P256Element::pack(octetStream& os, int) const { octet* buffer; size_t length = EC_POINT_point2buf(curve, point, @@ -103,9 +119,10 @@ void P256Element::pack(octetStream& os) const assert(length != 0); os.store_int(length, 8); os.append(buffer, length); + free(buffer); } -void P256Element::unpack(octetStream& os) +void P256Element::unpack(octetStream& os, int) { size_t length = os.get_int(8); assert( diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index e426bade9..bd005c840 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -22,20 +22,23 @@ class P256Element : public ValueInterface EC_POINT* point; public: - typedef void next; + typedef P256Element next; typedef void Square; static const true_type invertible; static int size() { return 0; } + static int length() { return 256; } static string type_string() { return "P256"; } static void init(); + static void finish(); P256Element(); P256Element(const P256Element& other); P256Element(const Scalar& other); P256Element(word other); + ~P256Element(); P256Element& operator=(const P256Element& other); @@ -57,8 +60,8 @@ class P256Element : public ValueInterface bool is_zero() { return *this == P256Element(); } void add(octetStream& os) { *this += os.get(); } - void pack(octetStream& os) const; - void unpack(octetStream& os); + void pack(octetStream& os, int = -1) const; + void unpack(octetStream& os, int = -1); octetStream hash(size_t n_bytes) const; diff --git a/ECDSA/README.md b/ECDSA/README.md index 6307a91e6..7cb014ce8 100644 --- a/ECDSA/README.md +++ b/ECDSA/README.md @@ -24,8 +24,8 @@ The following binaries have been used for the paper: All binaries offer the same interface. With MASCOT for example, run the following: ``` -./mascot-ecsda-party.x -p 0 [-N ] [-h ] [-D] [] -./mascot-ecsda-party.x -p 1 [-N ] [-h ] [-D] [] +./mascot-ecdsa-party.x -p 0 [-N ] [-h ] [-D] [] +./mascot-ecdsa-party.x -p 1 [-N ] [-h ] [-D] [] ... ``` diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 6cc2fbcc8..ea19c8ee3 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -3,6 +3,8 @@ * */ +#define NO_MIXED_CIRCUITS + #include "Networking/Server.h" #include "Networking/CryptoPlayer.h" #include "Math/gfp.h" @@ -43,14 +45,13 @@ int main(int argc, const char** argv) string prefix = get_prep_sub_dir(PREP_DIR "ECDSA/", 2); read_mac_key(prefix, N, keyp); + pShare::MAC_Check::setup(P); + Share::MAC_Check::setup(P); + DataPositions usage; Sub_Data_Files prep(N, prefix, usage); typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); - BaseMachine machine; - machine.ot_setups.push_back({P, false}); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, {}, usage); SubProcessor proc(_, MCp, prep, P); pShare sk, __; @@ -60,4 +61,8 @@ int main(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts); + + pShare::MAC_Check::teardown(); + Share::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index a68f8e833..07520f336 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -30,6 +30,8 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/MalRep.hpp" +#include "Machines/Rep.hpp" #include @@ -52,10 +54,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); pShare sk = typename T::Honest::Protocol(P).get_random(); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; DataPositions usage; @@ -69,4 +71,5 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); // check(tuples, sk, {}, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + P256Element::finish(); } diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index dc2edab31..920397cef 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -3,6 +3,10 @@ * */ +#define NO_MIXED_CIRCUITS + +#define NO_SECURITY_CHECK + #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 1c4a442ae..550c0ac8a 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -92,9 +92,6 @@ void run(int argc, const char** argv) P256Element::init(); P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); - BaseMachine machine; - machine.ot_setups.push_back({P, true}); - P256Element::Scalar keyp; SeededPRNG G; keyp.randomize(G); @@ -102,12 +99,13 @@ void run(int argc, const char** argv) typedef T pShare; DataPositions usage; + pShare::MAC_Check::setup(P); + T::MAC_Check::setup(P); + OnlineOptions::singleton.batch_size = 1; typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); typename pShare::TriplePrep sk_prep(0, usage); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, {}, usage); SubProcessor sk_proc(_, MCp, sk_prep, P); pShare sk, __; // synchronize @@ -115,10 +113,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); sk_prep.get_two(DATA_INVERSE, sk, __); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; typename pShare::TriplePrep prep(0, usage); @@ -139,4 +137,8 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + + pShare::MAC_Check::teardown(); + T::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 334d5d1ba..0a5e0ab9c 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -41,8 +41,8 @@ void preprocessing(vector>& tuples, int buffer_size, timer.start(); Player& P = proc.P; auto& prep = proc.DataF; - size_t start = P.sent + prep.data_sent(); - auto stats = P.comm_stats + prep.comm_stats(); + size_t start = P.total_comm().sent; + auto stats = P.total_comm(); auto& extra_player = P; auto& protocol = proc.protocol; @@ -77,7 +77,7 @@ void preprocessing(vector>& tuples, int buffer_size, MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); protocol.start_exchange(); @@ -106,9 +106,9 @@ void preprocessing(vector>& tuples, int buffer_size, timer.stop(); cout << "Generated " << buffer_size << " tuples in " << timer.elapsed() << " seconds, throughput " << buffer_size / timer.elapsed() << ", " - << 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size + << 1e-3 * (P.total_comm().sent - start) / buffer_size << " kbytes per tuple" << endl; - (P.comm_stats + prep.comm_stats() - stats).print(true); + (P.total_comm() - stats).print(true); } template class T> diff --git a/ECDSA/semi-ecdsa-party.cpp b/ECDSA/semi-ecdsa-party.cpp index 6bdcec286..d7a4d8836 100644 --- a/ECDSA/semi-ecdsa-party.cpp +++ b/ECDSA/semi-ecdsa-party.cpp @@ -10,6 +10,7 @@ #include "Protocols/SemiPrep.hpp" #include "Protocols/SemiInput.hpp" #include "Protocols/MAC_Check_Base.hpp" +#include "GC/SemiSecret.hpp" #include "ot-ecdsa-party.hpp" #include diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index 10991276a..5686349e2 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length, (void) pk; Timer timer; timer.start(); - size_t start = P.sent; - auto stats = P.comm_stats; + auto stats = P.total_comm(); EcSignature signature; vector opened_R; if (opts.R_after_msg) @@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length, auto& protocol = proc->protocol; if (proc) { - protocol.init_mul(proc); + protocol.init_mul(); protocol.prepare_mul(sk, tuple.a); protocol.start_exchange(); } @@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length, auto rx = tuple.R.x(); signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); + auto diff = (P.total_comm() - stats); cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending " - << (P.sent - start) << " bytes" << endl; - auto diff = (P.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(true); return signature; } @@ -139,11 +138,11 @@ void sign_benchmark(vector>& tuples, T sk, P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); P256Element pk = MCc.open(sk, P); MCc.Check(P); cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); for (size_t i = 0; i < min(10lu, tuples.size()); i++) { @@ -154,13 +153,12 @@ void sign_benchmark(vector>& tuples, T sk, Timer timer; timer.start(); auto& check_player = MCp.get_check_player(P); - auto stats = check_player.comm_stats; - auto start = check_player.sent; + auto stats = check_player.total_comm(); MCp.Check(P); MCc.Check(P); + auto diff = (check_player.total_comm() - stats); cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending " - << (check_player.sent - start) << " bytes" << endl; - auto diff = (check_player.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(); } } diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 12ba1c938..fc5571b1d 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -8,24 +8,92 @@ #include "Networking/ssl_sockets.h" +#ifdef NO_CLIENT_TLS +class client_ctx +{ +public: + client_ctx(string) + { + } +}; + +class client_socket +{ +public: + int socket; + + client_socket(boost::asio::io_service&, + client_ctx&, int plaintext_socket, string, + string, bool) : socket(plaintext_socket) + { + } + + ~client_socket() + { + close(socket); + } +}; + +inline void send(client_socket* socket, octet* data, size_t len) +{ + send(socket->socket, data, len); +} + +inline void receive(client_socket* socket, octet* data, size_t len) +{ + receive(socket->socket, data, len); +} + +#else + +typedef ssl_ctx client_ctx; +typedef ssl_socket client_socket; + +#endif + +/** + * Client-side interface + */ class Client { vector plain_sockets; - ssl_ctx ctx; + client_ctx ctx; ssl_service io_service; public: - vector sockets; + /** + * Sockets for cleartext communication + */ + vector sockets; + + /** + * Specification of computation domain + */ octetStream specification; + /** + * Start a new set of connections to computing parties. + * @param hostnames location of computing parties + * @param port_base port base + * @param my_client_id client identifier + */ Client(const vector& hostnames, int port_base, int my_client_id); ~Client(); + /** + * Securely input private values. + * @param values vector of integer-like values + */ template void send_private_inputs(const vector& values); - template - vector receive_outputs(int n); + /** + * Securely receive output values. + * @param n number of values + * @returns vector of integer-like values + */ + template + vector receive_outputs(int n); }; #endif /* EXTERNALIO_CLIENT_H_ */ diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index 601d9a486..46e6fc32c 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -5,10 +5,8 @@ #include "Client.h" -inline -Client::Client(const vector& hostnames, int port_base, - int my_client_id) : - ctx("C" + to_string(my_client_id)) +inline Client::Client(const vector& hostnames, int port_base, int my_client_id) + : ctx("C" + to_string(my_client_id)) { bigint::init_thread(); @@ -20,15 +18,14 @@ Client::Client(const vector& hostnames, int port_base, { set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i); octetStream(to_string(my_client_id)).Send(plain_sockets[i]); - sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], - "P" + to_string(i), "C" + to_string(my_client_id), true); + // sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], + // "P" + to_string(i), "C" + to_string(my_client_id), true); if (i == 0) specification.Receive(sockets[0]); } } -inline -Client::~Client() +inline Client::~Client() { for (auto& socket : sockets) { @@ -37,24 +34,28 @@ Client::~Client() } // Send the private inputs masked with a random value. -// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. -// Add the private input value to triple[0] and send to each spdz engine. -template +// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are +// valid. Add the private input value to triple[0] and send to each spdz engine. +template void Client::send_private_inputs(const vector& values) { int num_inputs = values.size(); octetStream os; - vector< vector > triples(num_inputs, vector(3)); + vector > triples(num_inputs, vector(3)); vector triple_shares(3); // Receive num_inputs triples from SPDZ for (size_t j = 0; j < sockets.size(); j++) { +#ifdef VERBOSE_COMM + cerr << "receiving from " << j << endl << flush; +#endif + os.reset_write_head(); os.Receive(sockets[j]); #ifdef VERBOSE_COMM - cerr << "received " << os.get_length() << " from " << j << endl; + cerr << "received " << os.get_length() << " from " << j << endl << flush; #endif for (int j = 0; j < num_inputs; j++) @@ -91,8 +92,8 @@ void Client::send_private_inputs(const vector& values) // Receive shares of the result and sum together. // Also receive authenticating values. -template -vector Client::receive_outputs(int n) +template +vector Client::receive_outputs(int n) { vector triples(3 * n); octetStream os; @@ -101,7 +102,7 @@ vector Client::receive_outputs(int n) os.reset_write_head(); os.Receive(socket); #ifdef VERBOSE_COMM - cout << "received " << os.get_length() << endl; + cout << "received " << os.get_length() << endl << flush; #endif for (int j = 0; j < 3 * n; j++) { @@ -111,7 +112,7 @@ vector Client::receive_outputs(int n) } } - vector output_values; + vector output_values; for (int i = 0; i < 3 * n; i += 3) { if (T(triples[i] * triples[i + 1]) != triples[i + 2]) diff --git a/ExternalIO/README.md b/ExternalIO/README.md index 02b4e1e8e..b841b0bbc 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -1,24 +1,30 @@ -The ExternalIO directory contains an example of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md). +The ExternalIO directory contains an example of managing I/O between +external client processes and parties running MP-SPDZ engines. These +instructions assume that MP-SPDZ has been built as per the [project +readme](../README.md). ## Working Examples -[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a +[bankers-bonus-client.cpp](../ExternalIO/bankers-bonus-client.cpp) and +[bankers-bonus-client.py](../ExternalIO/bankers-bonus-client.py) act as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output as described by [Damgård et al.](https://eprint.iacr.org/2015/1006) The computation allows up to eight clients to input a number and computes the client -with the largest input. You can run it as follows from the main +with the largest input. You can run the C++ code as follows from the main directory: ``` make bankers-bonus-client.x ./compile.py bankers_bonus 1 Scripts/setup-ssl.sh Scripts/setup-clients.sh 3 -Scripts/.sh & +PLAYERS= Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 0 100 0 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 ``` +`` can be any arithmetic protocol (e.g., `mascot`) but not a +binary protocol (e.g., `yao`). This should output that the winning id is 1. Note that the ids have to be incremental, and the client with the highest id has to input 1 as the last argument while the others have to input 0 there. Furthermore, @@ -28,58 +34,30 @@ protocol script. The setup scripts generate the necessary SSL certificates and keys. Therefore, if you run the computation on different hosts, you will have to distribute the `*.pem` files. +For the Python client, make sure to install +[gmpy2](https://pypi.org/project/gmpy2), and run +`ExternalIO/bankers-bonus-client.py` instead of +`bankers-bonus-client.x`. + ## I/O MPC Instructions ### Connection Setup -**listen**(*int port_num*) - -Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background. - -*port_num* - the port number to listen on. - -**acceptclientconnection**(*regint client_socket_id*, *int port_num*) - -Picks the first available client socket connection. Blocks if none available. - -*client_socket_id* - an identifier used to refer to the client socket. - -*port_num* - the port number identifies the socket server to accept connections on. +1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients) +2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection) +3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection) ### Data Exchange -Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py). - -*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*) - -Read a share of an input from a client, blocking on the client send. - -*client_socket_id* - an identifier used to refer to the client socket. - -*number_of_inputs* - the number of inputs expected - -*[inputs]* - returned list of shares of private input. - -**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*) - -Write shares of values including macs to an external client. - -*client_socket_id* - an identifier used to refer to the client socket. - -*[values]* - list of shares of values to send to client. - -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. - -See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message. - -*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*) - -Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf) - -*number_of_inputs* - the number of inputs expected +Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types). -*client_socket_id* - an identifier used to refer to the client socket. +1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket) +2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client) +3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients) -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. +## Client-Side Interface -*[inputs]* - returned list of shares of private input. +The example uses the `Client` class implemented in +`ExternalIO/Client.hpp` to handle the communication, see +[this reference](https://mp-spdz.readthedocs.io/en/latest/io.html#reference) for +documentation. diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index f68384c00..b040dd5e8 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -46,7 +46,7 @@ #include #include -template +template void one_run(T salary_value, Client& client) { // Run the computation @@ -54,18 +54,18 @@ void one_run(T salary_value, Client& client) cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; // Get the result back (client_id of winning client) - T result = client.receive_outputs(1)[0]; + U result = client.receive_outputs(1)[0]; cout << "Winning client id is : " << result << endl; } -template +template void run(double salary_value, Client& client) { // sint - one_run(long(round(salary_value)), client); + one_run(long(round(salary_value)), client); // sfix with f = 16 - one_run(long(round(salary_value * exp2(16))), client); + one_run(long(round(salary_value * exp2(16))), client); } int main(int argc, char** argv) @@ -125,7 +125,7 @@ int main(int argc, char** argv) { gfp::init_field(specification.get()); cerr << "using prime " << gfp::pr() << endl; - run(salary_value, client); + run(salary_value, client); break; } case 'R': @@ -134,13 +134,13 @@ int main(int argc, char** argv) switch (R) { case 64: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; case 104: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; case 128: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; default: cerr << R << "-bit ring not implemented"; diff --git a/ExternalIO/bankers-bonus-client.py b/ExternalIO/bankers-bonus-client.py new file mode 100755 index 000000000..d0f8d285b --- /dev/null +++ b/ExternalIO/bankers-bonus-client.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 + +import sys + +sys.path.append('.') + +from client import * +from domains import * + +client_id = int(sys.argv[1]) +n_parties = int(sys.argv[2]) +bonus = float(sys.argv[3]) +finish = int(sys.argv[4]) + +client = Client(['localhost'] * n_parties, 14000, client_id) + +type = client.specification.get_int(4) + +if type == ord('R'): + domain = Z2(client.specification.get_int(4)) +elif type == ord('p'): + domain = Fp(client.specification.get_bigint()) +else: + raise Exception('invalid type') + +for socket in client.sockets: + os = octetStream() + os.store(finish) + os.Send(socket) + +for x in bonus, bonus * 2 ** 16: + client.send_private_inputs([domain(x)]) + + print('Winning client id is :', + client.receive_outputs(domain, 1)[0].v % 2 ** 64) diff --git a/ExternalIO/client.py b/ExternalIO/client.py new file mode 100644 index 000000000..a6fd0b035 --- /dev/null +++ b/ExternalIO/client.py @@ -0,0 +1,115 @@ +import socket, ssl +import struct +import time + +class Client: + def __init__(self, hostnames, port_base, my_client_id): + ctx = ssl.SSLContext() + name = 'C%d' % my_client_id + prefix = 'Player-Data/%s' % name + ctx.load_cert_chain(certfile=prefix + '.pem', keyfile=prefix + '.key') + ctx.load_verify_locations(capath='Player-Data') + + self.sockets = [] + for i, hostname in enumerate(hostnames): + for j in range(10000): + try: + plain_socket = socket.create_connection( + (hostname, port_base + i)) + break + except ConnectionRefusedError: + if j < 60: + time.sleep(1) + else: + raise + octetStream(b'%d' % my_client_id).Send(plain_socket) + self.sockets.append(ctx.wrap_socket(plain_socket, + server_hostname='P%d' % i)) + + self.specification = octetStream() + self.specification.Receive(self.sockets[0]) + + def receive_triples(self, T, n): + triples = [[0, 0, 0] for i in range(n)] + os = octetStream() + for socket in self.sockets: + os.Receive(socket) + for triple in triples: + for i in range(3): + t = T() + t.unpack(os) + triple[i] += t + res = [] + for triple in triples: + prod = triple[0] * triple[1] + if prod != triple[2]: + raise Exception( + 'invalid triple, diff %s' % hex(prod.v - triple[2].v)) + return triples + + def send_private_inputs(self, values): + T = type(values[0]) + triples = self.receive_triples(T, len(values)) + os = octetStream() + assert len(values) == len(triples) + for value, triple in zip(values, triples): + (value + triple[0]).pack(os) + for socket in self.sockets: + os.Send(socket) + + def receive_outputs(self, T, n): + triples = self.receive_triples(T, n) + return [triple[0] for triple in triples] + +class octetStream: + def __init__(self, value=None): + self.buf = b'' + self.ptr = 0 + if value is not None: + self.buf += value + + def reset_write_head(self): + self.buf = b'' + self.ptr = 0 + + def Send(self, socket): + socket.sendall(struct.pack('= 0) + + def __add__(self, other): + try: + res = self.v + other.v + except: + res = self.v + other + return type(self)(res) + + def __mul__(self, other): + try: + res = self.v * other.v + except: + res = self.v * other + return type(self)(res) + + __radd__ = __add__ + + def __eq__(self, other): + return self.v == other.v + + def __neq__(self, other): + return self.v != other.v + + def unpack(self, os): + self.v = 0 + buf = os.consume(self.n_bytes) + for i, b in enumerate(buf): + self.v += b << (i * 8) + + def pack(self, os): + v = self.v + temp_buf = [] + for i in range(self.n_bytes): + temp_buf.append(v & 0xff) + v >>= 8 + #Instead of using python a loop per value we let struct pack handle all it + os.buf += struct.pack('<{}B'.format(len(temp_buf)), *tuple(temp_buf)) + +def Z2(k): + class Z(Domain): + modulus = 2 ** k + n_words = (k + 63) // 64 + n_bytes = (k + 7) // 8 + + return Z + +def Fp(mod): + import gmpy2 + + class Fp(Domain): + modulus = mod + n_words = (modulus.bit_length() + 63) // 64 + n_bytes = 8 * n_words + R = 2 ** (64 * n_words) % modulus + R_inv = gmpy2.invert(R, modulus) + + def unpack(self, os): + Domain.unpack(self, os) + self.v = self.v * self.R_inv % self.modulus + + def pack(self, os): + Domain.pack(type(self)(self.v * self.R), os) + + return Fp diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index 1efe1e228..b0a287444 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -58,7 +58,8 @@ class AddableVector: public vector { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - for (unsigned int i = 0; i < this->size(); i++) + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) (*this)[i] += y[i]; return *this; } @@ -67,9 +68,11 @@ class AddableVector: public vector { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - AddableVector res(y.size()); - for (unsigned int i = 0; i < this->size(); i++) - res[i] = (*this)[i] - y[i]; + AddableVector res; + res.reserve(y.size()); + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) + res.push_back((*this)[i] - y[i]); return res; } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 9afef83ce..62cbd5281 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -31,6 +31,12 @@ word check_pk_id(word a, word b) } +void Ciphertext::Scale() +{ + Scale(params->get_plaintext_modulus()); +} + + void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1) { if (c0.params!=c1.params) { throw params_mismatch(); } @@ -115,9 +121,28 @@ void Ciphertext::add(octetStream& os) *this += tmp; } +void Ciphertext::rerandomize(const FHE_PK& pk) +{ + Rq_Element tmp(*params); + SeededPRNG G; + vector r(params->FFTD()[0].m()); + bigint p = pk.p(); + assert(p != 0); + for (auto& x : r) + { + G.get(x, params->p0().numBits() - p.numBits() - 1); + x *= p; + } + tmp.from(r, 0); + Scale(); + cc0 += tmp; + auto zero = pk.encrypt(*params); + zero.Scale(pk.p()); + *this += zero; +} + template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); -template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); - - +template void mul(Ciphertext& ans, const Plaintext& a, + const Ciphertext& c); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index d455f1268..11a23e2ab 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -15,6 +15,12 @@ template void mul(Ciphertext& ans,const Ciphertext& c, void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1); void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk); +/** + * BGV ciphertext. + * The class allows adding two ciphertexts as well as adding a plaintext and + * a ciphertext via operator overloading. The multiplication of two ciphertexts + * requires the public key and thus needs a separate function. + */ class Ciphertext { Rq_Element cc0,cc1; @@ -54,6 +60,7 @@ class Ciphertext // Scale down an element from level 1 to level 0, if at level 0 do nothing void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); } + void Scale(); // Throws error if ans,c0,c1 etc have different params settings // - Thus programmer needs to ensure this rather than this being done @@ -90,6 +97,12 @@ class Ciphertext template Ciphertext& operator*=(const Plaintext_& other) { ::mul(*this, *this, other); return *this; } + /** + * Ciphertext multiplication. + * @param pk public key + * @param x second ciphertext + * @returns product ciphertext + */ Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const { Ciphertext res(*params); ::mul(res, *this, x, pk); return res; } @@ -98,14 +111,18 @@ class Ciphertext return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this}; } + /// Re-randomize for circuit privacy. + void rerandomize(const FHE_PK& pk); + int level() const { return cc0.level(); } - // pack/unpack (like IO) also assume params are known and already set - // correctly + /// Append to buffer void pack(octetStream& o) const { cc0.pack(o); cc1.pack(o); o.store(pk_id); } - void unpack(octetStream& o) - { cc0.unpack(o); cc1.unpack(o); o.get(pk_id); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& o) + { cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); } void output(ostream& s) const { cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); } diff --git a/FHE/Diagonalizer.cpp b/FHE/Diagonalizer.cpp index 9cc1a0840..958cd28cc 100644 --- a/FHE/Diagonalizer.cpp +++ b/FHE/Diagonalizer.cpp @@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag( { auto& c = products.at(i); for (int j = 0; j < n_matrices; j++) + { + res.at(j).entries.init(); for (size_t k = 0; k < n_rows; k++) res.at(j)[{k, i}] = c.element(j * n_rows + k); + } } return res; } diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index c71a4c5da..d3b67b506 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -7,6 +7,11 @@ +FFT_Data::FFT_Data() : + twop(-1) +{ +} + void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) { R=Rg; diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index c5d6b2063..4fb37ed48 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -50,7 +50,7 @@ class FFT_Data void pack(octetStream& o) const; void unpack(octetStream& o); - FFT_Data() { ; } + FFT_Data(); FFT_Data(const Ring& Rg,const Zp_Data& PrD) { init(Rg,PrD); } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 2a4d6b123..742c85452 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -12,6 +12,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p()) { } +FHE_SK::FHE_SK(const FHE_Params& pms) : + FHE_SK(pms, pms.get_plaintext_modulus()) +{ +} + FHE_SK& FHE_SK::operator+=(const FHE_SK& c) { @@ -38,6 +43,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G) } +FHE_PK::FHE_PK(const FHE_Params& pms) : + FHE_PK(pms, pms.get_plaintext_modulus()) +{ +} + Rq_Element FHE_PK::sample_secret_key(PRNG& G) { Rq_Element sk = FHE_SK(*this).s(); @@ -47,11 +57,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G) } void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) +{ + Rq_Element a(*this); + a.randomize(G); + partial_key_gen(sk, a, G, noise_boost); +} + +void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost) { FHE_PK& PK = *this; - // Generate the main public key - PK.a0.randomize(G); + a0 = a; // b0=a0*s+p*e0 Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); @@ -77,9 +94,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) mul(es,es,PK.pr); add(PK.Sw_b,PK.Sw_b,es); - // Lowering level as we only decrypt at level 0 - sk.lower_level(); - // bs=bs-p1*s^2 Rq_Element s2; mul(s2,sk,sk); // Mult at level 0 @@ -175,32 +189,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext& template Ciphertext FHE_PK::encrypt( const Plaintext& mess) const +{ + return encrypt(Rq_Element(*params, mess)); +} + +Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const { Random_Coins rc(*params); PRNG G; G.ReSeed(); rc.generate(G); - return encrypt(mess, rc); + Ciphertext res(*params); + quasi_encrypt(res, mess, rc); + return res; } template void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const { - if (&c.get_params()!=params) { throw params_mismatch(); } if (T::characteristic_two ^ (pr == 2)) throw pr_mismatch(); + Rq_Element ans = quasi_decrypt(c); + mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); +} + +Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const +{ + if (&c.get_params()!=params) { throw params_mismatch(); } + Rq_Element ans; mul(ans,c.c1(),sk); sub(ans,c.c0(),ans); ans.change_rep(polynomial); - mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); + return ans; } +Plaintext_ FHE_SK::decrypt(const Ciphertext& c) +{ + return decrypt(c, params->get_plaintext_field_data()); +} + template Plaintext FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD) { @@ -295,12 +328,12 @@ void FHE_PK::unpack(octetStream& o) o.consume((octet*) tag, 8); if (memcmp(tag, "PKPKPKPK", 8)) throw runtime_error("invalid serialization of public key"); - a0.unpack(o); - b0.unpack(o); + a0.unpack(o, *params); + b0.unpack(o, *params); if (params->n_mults() > 0) { - Sw_a.unpack(o); - Sw_b.unpack(o); + Sw_a.unpack(o, *params); + Sw_b.unpack(o, *params); } pr.unpack(o); } @@ -318,7 +351,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const return false; } - void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const { @@ -334,15 +366,13 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, template void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) { - check(*params, pk, pr); + check(*params, pk, FieldD.get_prime()); pk.check_noise(*this); if (decrypt(pk.encrypt(Plaintext_(FieldD)), FieldD) != Plaintext_(FieldD)) throw runtime_error("incorrect key pair"); } - - void FHE_PK::check(const FHE_Params& params, const bigint& pr) const { if (this->pr != pr) @@ -357,6 +387,24 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const } } +bigint FHE_SK::get_noise(const Ciphertext& c) +{ + sk.lower_level(); + Ciphertext cc = c; + if (cc.level()) + cc.Scale(); + Rq_Element tmp = quasi_decrypt(cc); + bigint res; + bigint q = tmp.get_modulus(); + bigint half_q = q / 2; + for (auto& x : tmp.to_vec_bigint()) + { +// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " "; + res = max(res, x > half_q ? x - q : x); + } + return res; +} + template void FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 72a7ddfa8..f342e203b 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -12,6 +12,10 @@ class FHE_PK; class Ciphertext; +/** + * BGV secret key. + * The class allows addition. + */ class FHE_SK { Rq_Element sk; @@ -29,6 +33,8 @@ class FHE_SK // secret key always on lower level void assign(const Rq_Element& s) { sk=s; sk.lower_level(); } + FHE_SK(const FHE_Params& pms); + FHE_SK(const FHE_Params& pms, const bigint& p) : sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; } @@ -38,8 +44,11 @@ class FHE_SK const Rq_Element& s() const { return sk; } + /// Append to buffer void pack(octetStream& os) const { sk.pack(os); pr.pack(os); } - void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); } // Assumes Ring and prime of mess have already been set correctly // Ciphertext c must be at level 0 or an error occurs @@ -50,9 +59,14 @@ class FHE_SK template Plaintext decrypt(const Ciphertext& c, const FD& FieldD); + /// Decryption for cleartexts modulo prime + Plaintext_ decrypt(const Ciphertext& c); + template void decrypt_any(Plaintext_& mess, const Ciphertext& c); + Rq_Element quasi_decrypt(const Ciphertext& c) const; + // Three stage procedure for Distributed Decryption // - First stage produces my shares // - Second stage adds in another players shares, do this once for each other player @@ -62,7 +76,6 @@ class FHE_SK void dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_number,int num_players) const; void dist_decrypt_2(vector& vv,const vector& vv1) const; - friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); /* Add secret keys @@ -82,10 +95,15 @@ class FHE_SK template void check(const FHE_PK& pk, const FD& FieldD); + bigint get_noise(const Ciphertext& c); + friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; } }; +/** + * BGV public key. + */ class FHE_PK { Rq_Element a0,b0; @@ -104,8 +122,10 @@ class FHE_PK ) { a0=a; b0=b; Sw_a=sa; Sw_b=sb; } - - FHE_PK(const FHE_Params& pms, const bigint& p = 0) + + FHE_PK(const FHE_Params& pms); + + FHE_PK(const FHE_Params& pms, const bigint& p) : a0(pms.FFTD(),evaluation,evaluation), b0(pms.FFTD(),evaluation,evaluation), Sw_a(pms.FFTD(),evaluation,evaluation), @@ -143,19 +163,26 @@ class FHE_PK template Ciphertext encrypt(const Plaintext& mess, const Random_Coins& rc) const; + + /// Encryption template Ciphertext encrypt(const Plaintext& mess) const; + Ciphertext encrypt(const Rq_Element& mess) const; friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); Rq_Element sample_secret_key(PRNG& G); void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1); + void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost = 1); void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; - // params setting is done out of these IO/pack/unpack functions + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); bool operator!=(const FHE_PK& x) const; @@ -168,21 +195,39 @@ class FHE_PK void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); +/** + * BGV key pair + */ class FHE_KeyPair { public: + /// Public key FHE_PK pk; + /// Secret key FHE_SK sk; - FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) : + FHE_KeyPair(const FHE_Params& params, const bigint& pr) : pk(params, pr), sk(params, pr) { } + /// Initialization + FHE_KeyPair(const FHE_Params& params) : + pk(params), sk(params) + { + } + void generate(PRNG& G) { KeyGen(pk, sk, G); } + + /// Generate fresh keys + void generate() + { + SeededPRNG G; + generate(G); + } }; template diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 8ae6c2885..5a0f3991c 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -1,7 +1,15 @@ #include "FHE_Params.h" +#include "NTL-Subs.h" #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" +#include "Protocols/HemiOptions.h" +#include "Processor/OnlineOptions.h" + +FHE_Params::FHE_Params(int n_mults, int drown_sec) : + FFTData(n_mults + 1), Chi(0.7), sec_p(drown_sec), matrix_dim(1) +{ +} void FHE_Params::set(const Ring& R, const vector& primes) @@ -12,16 +20,35 @@ void FHE_Params::set(const Ring& R, for (size_t i = 0; i < FFTData.size(); i++) FFTData[i].init(R,primes[i]); - set_sec(40); + set_sec(sec_p); } void FHE_Params::set_sec(int sec) { + assert(sec >= 0); sec_p=sec; Bval=1; Bval=Bval< 0); + if (FFTData[0].get_prime() != 0) + throw runtime_error("cannot change matrix dimension after parameter generation"); + this->matrix_dim = matrix_dim; +} + +void FHE_Params::set_matrix_dim_from_options() +{ + set_matrix_dim( + HemiOptions::singleton.plain_matmul ? + 1 : OnlineOptions::singleton.batch_size); } bigint FHE_Params::Q() const @@ -40,6 +67,8 @@ void FHE_Params::pack(octetStream& o) const Chi.pack(o); Bval.pack(o); o.store(sec_p); + o.store(matrix_dim); + fd.pack(o); } void FHE_Params::unpack(octetStream& o) @@ -52,6 +81,8 @@ void FHE_Params::unpack(octetStream& o) Chi.unpack(o); Bval.unpack(o); o.get(sec_p); + o.get(matrix_dim); + fd.unpack(o); } bool FHE_Params::operator!=(const FHE_Params& other) const @@ -64,3 +95,37 @@ bool FHE_Params::operator!=(const FHE_Params& other) const else return false; } + +void FHE_Params::basic_generation_mod_prime(int plaintext_length) +{ + if (n_mults() == 0) + generate_semi_setup(plaintext_length, 0, *this, fd, false); + else + { + Parameters parameters(1, plaintext_length, 0); + parameters.generate_setup(*this, fd); + } +} + +template<> +const FFT_Data& FHE_Params::get_plaintext_field_data() const +{ + return fd; +} + +template<> +const P2Data& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +template<> +const PPData& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +bigint FHE_Params::get_plaintext_modulus() const +{ + return fd.get_prime(); +} diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index ac56668a2..4733245ca 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -13,7 +13,11 @@ #include "FHE/FFT_Data.h" #include "FHE/DiscreteGauss.h" #include "Tools/random.h" +#include "Protocols/config.h" +/** + * Cryptosystem parameters + */ class FHE_Params { protected: @@ -26,19 +30,29 @@ class FHE_Params // Data for distributed decryption int sec_p; bigint Bval; + int matrix_dim; + + FFT_Data fd; public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {} + /** + * Initialization. + * @param n_mults number of ciphertext multiplications (0/1) + * @param drown_sec parameter for function privacy (default 40) + */ + FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY); int n_mults() const { return FFTData.size() - 1; } - // Rely on default copy assignment/constructor (not that they should - // ever be needed) - void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); + void set_min_sec(int sec); + + void set_matrix_dim(int matrix_dim); + void set_matrix_dim_from_options(); + int get_matrix_dim() const { return matrix_dim; } const vector& FFTD() const { return FFTData; } @@ -55,10 +69,24 @@ class FHE_Params int phi_m() const { return FFTData[0].phi_m(); } const Ring& get_ring() { return FFTData[0].get_R(); } + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer void unpack(octetStream& o); bool operator!=(const FHE_Params& other) const; + + /** + * Generate parameter for computation modulo a prime + * @param plaintext_length bit length of prime + */ + void basic_generation_mod_prime(int plaintext_length); + + template + const FD& get_plaintext_field_data() const; + + bigint get_plaintext_modulus() const; }; #endif diff --git a/FHE/Matrix.cpp b/FHE/Matrix.cpp index c9c23aaba..dcec137e4 100644 --- a/FHE/Matrix.cpp +++ b/FHE/Matrix.cpp @@ -68,6 +68,10 @@ void HNF(matrix& H,matrix& U,const matrix& A) { int m=A.size(),n=A[0].size(),r,i,j,k; +#ifdef VERBOSE + cerr << "HNF m=" << m << ", n=" << n << endl; +#endif + H=A; ident(U,n); r=min(m,n); @@ -79,9 +83,9 @@ void HNF(matrix& H,matrix& U,const matrix& A) { if (step==2) { // Step 2 k=-1; - mn=bigint(1)<<256; + mn=bigint(0); for (j=i; j 0) cout << "+" << lgp1; cout << " and " << phi_N(m) << " slots" << endl; +#endif int extra_slack = 0; if (round_up) @@ -124,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b } extra_slack = i - 1; lgp0 += extra_slack; +#ifdef VERBOSE cout << "Rounding up to " << lgp0 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } Ring R; @@ -147,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool round_up, FHE_Params& params) { + (void) lg2pi, (void) n; + +#ifdef VERBOSE if (n >= 2 and n <= 10) cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] << ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl; cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl; cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl; +#endif int extra_slack = 0; if (round_up) @@ -170,20 +179,18 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, extra_slack = 2 * i; lg2p0 += i; lg2p1 += i; +#ifdef VERBOSE cout << "Rounding up to " << lg2p0 << "+" << lg2p1 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } +#ifdef VERBOSE cout << "Total length: " << lg2p0 + lg2p1 << endl; +#endif return extra_slack; } - - - -/****************************************************************************** - * Here onwards needs NTL - ******************************************************************************/ @@ -220,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p, { double phi_m_bound = NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1); + +#ifdef VERBOSE cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl; +#endif + if (phi_N(m) < phi_m_bound) { int old_m = m; + (void) old_m; m = 2 << int(ceil(log2(phi_m_bound))); + +#ifdef VERBOSE cout << "m = " << old_m << " too small, increasing it to " << m << endl; +#endif + generate_prime(p, numBits(p), m); } else @@ -249,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p, void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, const string& i, const bigint& pr0) { + (void) i; + if (lg2pr==0) { throw invalid_params(); } bigint step=m; @@ -265,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, assert(numBits(pr) == lg2pr); } +#ifdef VERBOSE cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl; + cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; +#endif assert(pr % m == 1); assert(pr % p == 1); assert(numBits(pr) == lg2pr); - - cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; } /* @@ -345,10 +364,12 @@ ZZX Cyclotomic(int N) return F; } #else +// simplified version powers of two int phi_N(int N) { if (((N - 1) & N) != 0) - throw runtime_error("compile with NTL support"); + throw runtime_error( + "compile with NTL support (USE_NTL=1 in CONFIG.mine)"); else if (N == 1) return 1; else @@ -398,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly) for (int i=0; i Fi(Gord); vector Rts(Gord); @@ -595,6 +640,27 @@ void char_2_dimension(int& m, int& lg2) m=5797; lg2=40; break; + case 64: + m = 9615; + break; + case 63: + m = 9271; + break; + case 28: + m = 3277; + break; + case 16: + m = 4369; + break; + case 15: + m = 4681; + break; + case 12: + m = 4095; + break; + case 11: + m = 2047; + break; default: throw runtime_error("field size not supported"); break; @@ -630,7 +696,7 @@ void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D) finalize_lengths(lg2p0, lg2p1, n, m, lg2pi[0], round_up, params); } - if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) > phi_N(m)) + if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) * 2 > m) throw runtime_error("number of slots too small"); cout << "m = " << m << endl; diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index ab150d272..7d7d13de6 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -1,8 +1,6 @@ #ifndef _NTL_Subs #define _NTL_Subs -/* All these routines use NTL on the inside */ - #include "FHE/Ring.h" #include "FHE/FFT_Data.h" #include "FHE/P2Data.h" @@ -47,23 +45,29 @@ class Parameters }; -// Main setup routine (need NTL if online_only is false) +// Main setup routine void generate_setup(int nparties, int lgp, int lg2, int sec, bool skip_2 = false, int slack = 0, bool round_up = false); // semi-homomorphic, includes slack template int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FD& FieldD, bool round_up); + FHE_Params& params, FD& FieldD, bool round_up, int n = 1); // field-independent semi-homomorphic setup -int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, +int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up); -// Everything else needs NTL void init(Ring& Rg, int m, bool generate_poly); void init(P2Data& P2D,const Ring& Rg); +namespace NTL +{ +class GF2X; +} + +NTL::GF2X get_F(const Ring& Rg); + // For use when we want p to be a specific value void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, bigint& pr1, int n, int sec, bigint& p, FHE_Params& params); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index ae52fc62f..f4502317e 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -36,10 +36,12 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s); // unify parameters by taking maximum over TopGear or not bigint B_clean_top_gear = B_clean * 2; - bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); + bigint B_clean_not_top_gear = B_clean << max(slack - sec, 0); B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); + int matrix_dim = params.get_matrix_dim(); #ifdef NOISY + cout << "phi(m): " << phi_m << endl; cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "V_s: " << V_s << endl; cout << "c1: " << c1 << endl; @@ -48,9 +50,14 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, cout << "log(slack): " << slack << endl; cout << "B_clean: " << B_clean << endl; cout << "B_scale: " << B_scale << endl; + cout << "matrix dimension: " << matrix_dim << endl; + cout << "drown sec: " << params.secp() << endl; + cout << "sec: " << sec << endl; #endif - drown = 1 + n * (bigint(1) << sec); + assert(matrix_dim > 0); + assert(params.secp() >= 0); + drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp()); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) @@ -68,8 +75,14 @@ double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma) { if (sigma <= 0) sigma = FHE_Params().get_R(); - // the constant was updated using Martin Albrecht's LWE estimator in Sep 2019 - return 37.8 * (log_q - log2(sigma)); + // the constant was updated using Martin Albrecht's LWE estimator in Mar 2022 + // found the following pairs for 128-bit security + // and alpha = 0.7 * sqrt(2*pi) / q + // m = 2048, log_2(q) = 68 + // m = 4096, log_2(q) = 138 + // m = 8192, log_2(q) = 302 + // m = 16384, log_2(q) = 560 + return 15.1 * log_q; } double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, const FHE_Params& params) @@ -92,7 +105,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants() { tp *= t; double lgtp = log(tp) / log(2.0); - if (C[i] < 0 && lgtp < FHE_epsilon) + if (C[i] < 0 && lgtp < -FHE_epsilon) { C[i] = pow(x, i); } @@ -114,7 +127,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, cout << "n: " << n << endl; cout << "sec: " << sec << endl; cout << "sigma: " << this->sigma << endl; - cout << "h: " << h << endl; cout << "B_clean size: " << numBits(B_clean) << endl; cout << "B_scale size: " << numBits(B_scale) << endl; cout << "B_KS size: " << numBits(B_KS) << endl; @@ -155,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1) bigint NoiseBounds::min_p1() { - return drown * B_KS + 1; + return max(bigint(drown * B_KS), bigint((phi_m * p) << 10)); } bigint NoiseBounds::opt_p1() @@ -169,8 +181,10 @@ bigint NoiseBounds::opt_p1() // solve mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a); bigint res = ceil(s); +#ifdef VERBOSE cout << "Optimal p1 vs minimal: " << numBits(res) << "/" << numBits(min_p1()) << endl; +#endif return res; } @@ -182,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1) { min_p0 *= 2; min_p1 *= 2; +#ifdef VERBOSE cout << "increasing lengths: " << numBits(min_p0) << "/" << numBits(min_p1) << endl; +#endif } lg2p1 = numBits(min_p1); lg2p0 = numBits(min_p0); diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index ccd50808a..565c663ef 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -42,6 +42,8 @@ class SemiHomomorphicNoiseBounds bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); } static double min_phi_m(int log_q, double sigma); static double min_phi_m(int log_q, const FHE_Params& params); + + bigint get_B_clean() { return B_clean; } }; // as per ePrint 2012:642 for slack = 0 diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index 7d9a8ca47..ac4ae6f16 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -55,13 +55,13 @@ void P2Data::check_dimensions() const // cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl; if (A.size() != Ai.size()) throw runtime_error("forward and backward mapping dimensions mismatch"); - if (A.size() != A[0].size()) + if (A.size() != A.at(0).size()) throw runtime_error("forward mapping not square"); - if (Ai.size() != Ai[0].size()) + if (Ai.size() != Ai.at(0).size()) throw runtime_error("backward mapping not square"); - if ((int)A[0].size() != slots * gf2n_short::degree()) + if ((int)A.at(0).size() != slots * gf2n_short::degree()) throw runtime_error( - "mapping dimension incorrect: " + to_string(A[0].size()) + "mapping dimension incorrect: " + to_string(A.at(0).size()) + " != " + to_string(slots) + " * " + to_string(gf2n_short::degree())); } diff --git a/FHE/PPData.cpp b/FHE/PPData.cpp index 43bd48293..d9acc05d3 100644 --- a/FHE/PPData.cpp +++ b/FHE/PPData.cpp @@ -40,10 +40,9 @@ void PPData::to_eval(vector& elem) const */ } -void PPData::from_eval(vector& elem) const +void PPData::from_eval(vector&) const { // avoid warning - elem.empty(); throw not_implemented(); /* diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 84cbb9d19..4eba6e8f0 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -11,10 +11,43 @@ +template +Plaintext::Plaintext(const FHE_Params& params) : + Plaintext(params.get_plaintext_field_data(), Both) +{ +} + + +template +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).phi_m(); +} + +template +int Plaintext::degree() const +{ + return (*Field_Data).phi_m(); +} + + +template<> +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).num_slots(); +} + +template<> +int Plaintext::degree() const +{ + return (*Field_Data).degree(); +} + + template<> void Plaintext::from(const Generator& source) const { - b.resize(degree); + b.resize(degree()); for (auto& x : b) { source.get(bigint::tmp); @@ -31,7 +64,7 @@ void Plaintext::from_poly() const Ring_Element e(*Field_Data,polynomial); e.from(b); e.change_rep(evaluation); - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::from_poly() const for (unsigned int i=0; iget_prD()}; type=Both; @@ -90,7 +123,7 @@ template<> void Plaintext::from_poly() const { if (type!=Polynomial) { return; } - a.resize(n_slots); + a.resize(num_slots()); (*Field_Data).backward(a,b); type=Both; } @@ -106,34 +139,13 @@ void Plaintext::to_poly() const -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).num_slots(); - degree = (*Field_Data).degree(); -} - - template void Plaintext::allocate(PT_Type type) const { if (type != Evaluation) - b.resize(degree); + b.resize(degree()); if (type != Polynomial) - a.resize(n_slots); + a.resize(num_slots()); this->type = type; } @@ -141,7 +153,7 @@ void Plaintext::allocate(PT_Type type) const template void Plaintext::allocate_slots(const bigint& value) { - b.resize(degree); + b.resize(degree()); for (auto& x : b) x.allocate_slots(value); } @@ -236,7 +248,7 @@ void Plaintext::randomize(PRNG& G,condition cond) type=Polynomial; break; case Diagonal: - a.resize(n_slots); + a.resize(num_slots()); a[0].randomize(G); for (unsigned int i=1; i::randomize(PRNG& G,condition cond) break; default: // Gen a plaintext with 0/1 in each slot - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t) b[0].generateUniform(G, n_bits, false); } else - for (int i = 0; i < n_slots; i++) + for (size_t i = 0; i < num_slots(); i++) b[i].generateUniform(G, n_bits, false); break; default: @@ -288,7 +300,7 @@ void Plaintext::assign_zero(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::assign_one(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::equals(const Plaintext& x) const if (type!=Polynomial and x.type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::unpack(octetStream& o) unsigned int size; o.get(size); allocate(); - if (size != b.size()) + if (size != b.size() and size != 0) throw length_error("unexpected length received"); - for (unsigned int i = 0; i < b.size(); i++) + for (unsigned int i = 0; i < size; i++) b[i] = o.get(); } diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index 52ff8b6d4..c8fb93c73 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -18,6 +18,7 @@ */ #include "FHE/Generator.h" +#include "FHE/FFT_Data.h" #include "Math/fixint.h" #include @@ -25,6 +26,8 @@ using namespace std; class FHE_PK; class Rq_Element; +class FHE_Params; +class FFT_Data; template class AddableVector; // Forward declaration as apparently this is needed for friends in templates @@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits }; enum PT_Type { Polynomial, Evaluation, Both }; +/** + * BGV plaintext. + * Use ``Plaintext_mod_prime`` instead of filling in the templates. + * The plaintext is held in one of the two representations or both, + * polynomial and evaluation. The latter is the one allowing element-wise + * multiplication over a vector. + * Plaintexts can be added, subtracted, and multiplied via operator overloading. + */ template class Plaintext { typedef typename FD::poly_type S; - int n_slots; - int degree; mutable vector a; // The thing in evaluation/FFT form mutable vector b; // Now in polynomial form @@ -58,33 +67,47 @@ class Plaintext const FD *Field_Data; - void set_sizes(); + int degree() const; public: const FD& get_field() const { return *Field_Data; } - unsigned int num_slots() const { return n_slots; } + + /// Number of slots + unsigned int num_slots() const; Plaintext(const FD& FieldD, PT_Type type = Polynomial) - { Field_Data=&FieldD; set_sizes(); allocate(type); } + { Field_Data=&FieldD; allocate(type); } Plaintext(const FD& FieldD, const Rq_Element& other); + /// Initialization + Plaintext(const FHE_Params& params); + void allocate(PT_Type type) const; void allocate() const { allocate(type); } void allocate_slots(const bigint& value); int get_min_alloc(); - // Access evaluation representation + /** + * Read slot. + * @param i slot number + * @returns slot content + */ T element(int i) const { if (type==Polynomial) { from_poly(); } return a[i]; } + /** + * Write to slot + * @param i slot number + * @param e new slot content + */ void set_element(int i,const T& e) { if (type==Polynomial) { throw not_implemented(); } - a.resize(n_slots); + a.resize(num_slots()); a[i]=e; type=Evaluation; } @@ -171,10 +194,10 @@ class Plaintext bool is_diagonal() const; - /* Pack and unpack into an octetStream - * For unpack we assume the FFTD has been assigned correctly already - */ + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); size_t report_size(ReportType type); @@ -185,4 +208,6 @@ class Plaintext template using Plaintext_ = Plaintext; +typedef Plaintext_ Plaintext_mod_prime; + #endif diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp index c1c318b8d..3b63f3069 100644 --- a/FHE/Ring.cpp +++ b/FHE/Ring.cpp @@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o) o.get(pi_inv); o.get(poly); } - else + else if (mm != 0) init(*this, mm); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 9c2545ed8..39690fa6a 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -50,6 +50,7 @@ void Ring_Element::prepare_push() void Ring_Element::allocate() { + assert(FFTD); element.resize(FFTD->phi_m()); } @@ -86,7 +87,6 @@ void Ring_Element::negate() void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { - if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty()) { @@ -99,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) return; } + if (a.rep!=b.rep) { throw rep_mismatch(); } + if (&ans == &a) { ans += b; @@ -401,19 +403,29 @@ void Ring_Element::change_rep(RepType r) bool Ring_Element::equals(const Ring_Element& a) const { - if (element.empty() and a.element.empty()) - return true; - else if (element.empty() or a.element.empty()) - throw not_implemented(); - if (rep!=a.rep) { throw rep_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } + + if (is_zero() or a.is_zero()) + return is_zero() and a.is_zero(); + for (int i=0; i<(*FFTD).phi_m(); i++) { if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } } return true; } +bool Ring_Element::is_zero() const +{ + if (element.empty()) + return true; + for (auto& x : element) + if (not ::isZero(x, FFTD->get_prD())) + return false; + return true; +} + + ConversionIterator Ring_Element::get_iterator() const { if (rep != polynomial) @@ -560,6 +572,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const { if (&FFTD != this->FFTD) throw params_mismatch(); + if (is_zero()) + throw runtime_error("element is zero"); } diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index 5cc93ca9a..5982bbe32 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -95,6 +95,7 @@ class Ring_Element void randomize(PRNG& G,bool Diag=false); bool equals(const Ring_Element& a) const; + bool is_zero() const; // This is a NOP in cases where we cannot do a FFT void change_rep(RepType r); diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index af7a664b5..d6a14aabd 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -5,7 +5,7 @@ #include "Math/modp.hpp" Rq_Element::Rq_Element(const FHE_PK& pk) : - Rq_Element(pk.get_params().FFTD()) + Rq_Element(pk.get_params().FFTD(), evaluation, evaluation) { } @@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b) } } +void Rq_Element::add(octetStream& os) +{ + Rq_Element tmp(*this); + tmp.unpack(os); + *this += tmp; +} + void Rq_Element::randomize(PRNG& G,int l) { set_level(l); @@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p) // Now add delta back onto a0 Rq_Element bb(b0,b1); - add(*this,*this,bb); + ::add(*this,*this,bb); // Now divide by p1 mod p0 modp p1_inv,pp; @@ -340,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const return sz; } +void Rq_Element::unpack(octetStream& o, const FHE_Params& params) +{ + set_data(params.FFTD()); + unpack(o); +} + void Rq_Element::print_first_non_zero() const { vector v = to_vec_bigint(); diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index d5e718419..4e0cdf97b 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -69,8 +69,9 @@ class Rq_Element a({b0}), lev(n_mults()) {} template - Rq_Element(const FHE_Params& params, const Plaintext& plaintext) : - Rq_Element(params) + Rq_Element(const FHE_Params& params, const Plaintext& plaintext, + RepType r0 = polynomial, RepType r1 = polynomial) : + Rq_Element(params, r0, r1) { from(plaintext.get_iterator()); } @@ -93,12 +94,14 @@ class Rq_Element friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); + void add(octetStream& os); + template Rq_Element& operator+=(const vector& other); - Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; } + Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; } - Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; } + Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; } Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; } template Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; } @@ -157,6 +160,9 @@ class Rq_Element void pack(octetStream& o) const; void unpack(octetStream& o); + // without prior initialization + void unpack(octetStream& o, const FHE_Params& params); + void output(ostream& s) const; void input(istream& s); @@ -176,7 +182,7 @@ Rq_Element& Rq_Element::operator+=(const vector& other) { Rq_Element tmp = *this; tmp.from(Iterator(other), lev); - add(*this, *this, tmp); + ::add(*this, *this, tmp); return *this; } diff --git a/FHE/Subroutines.cpp b/FHE/Subroutines.cpp index 8ec6a7cee..a688b1691 100644 --- a/FHE/Subroutines.cpp +++ b/FHE/Subroutines.cpp @@ -11,35 +11,15 @@ void Subs(modp& ans,const vector& poly,const modp& x,const Zp_Data& ZpD) assignZero(ans,ZpD); for (int i=poly.size()-1; i>=0; i--) { Mul(ans,ans,x,ZpD); - switch (poly[i]) - { case 0: - break; - case 1: - Add(ans,ans,one,ZpD); - break; - case -1: - Sub(ans,ans,one,ZpD); - break; - case 2: - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - break; - case -2: - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - break; - case 3: - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - break; - case -3: - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - break; - default: - throw not_implemented(); + if (poly[i] > 0) + { + for (int j = 0; j < poly[i]; j++) + Add(ans, ans, one, ZpD); + } + if (poly[i] < 0) + { + for (int j = 0; j < -poly[i]; j++) + Sub(ans, ans, one, ZpD); } } } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 0f5d1fe86..0dc8d9675 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -40,10 +40,9 @@ template void PartSetup::generate_setup(int n_parties, int plaintext_length, int sec, int slack, bool round_up) { - sec = max(sec, 40); + params.set_min_sec(sec); Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup( params, FieldD); - params.set_sec(sec); pk = FHE_PK(params, FieldD.get_prime()); sk = FHE_SK(params, FieldD.get_prime()); calpha = Ciphertext(params); @@ -180,11 +179,8 @@ void PartSetup::init_field() } template -void PartSetup::check(int sec) const +void PartSetup::check() const { - sec = max(sec, 40); - if (abs(sec - params.secp()) > 2) - throw runtime_error("security parameters vary too much between protocol and distributed decryption"); sk.check(params, pk, FieldD.get_prime()); } @@ -203,7 +199,7 @@ template void PartSetup::secure_init(Player& P, MachineBase& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); } template diff --git a/FHEOffline/DataSetup.h b/FHEOffline/DataSetup.h index 88c9e05fc..1c606506d 100644 --- a/FHEOffline/DataSetup.h +++ b/FHEOffline/DataSetup.h @@ -57,7 +57,7 @@ class PartSetup void init_field(); - void check(int sec) const; + void check() const; bool operator!=(const PartSetup& other); void secure_init(Player& P, MachineBase& machine, int plaintext_length, diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 732904b39..3df98c851 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -57,7 +57,7 @@ void Multiplier::multiply_and_add(Plaintext_& res, template void Multiplier::add(Plaintext_& res, const Ciphertext& c, - OT_ROLE role, int n_summands) + OT_ROLE role, int) { o.reset_write_head(); @@ -67,20 +67,10 @@ void Multiplier::add(Plaintext_& res, const Ciphertext& c, G.ReSeed(); timers["Mask randomization"].start(); product_share.randomize(G); - bigint B = 6 * machine.setup().params.get_R(); - B *= machine.setup().FieldD.get_prime(); - B <<= machine.drown_sec; - // slack - B *= NonInteractiveProof::slack(machine.sec, - machine.setup().params.phi_m()); - B <<= machine.extra_slack; - B *= n_summands; - rc.generateUniform(G, 0, B, B); + mask = c; + mask.rerandomize(other_pk); timers["Mask randomization"].stop(); - timers["Encryption"].start(); - other_pk.encrypt(mask, product_share, rc); - timers["Encryption"].stop(); - mask += c; + mask += product_share; mask.pack(o); res -= product_share; } @@ -130,6 +120,13 @@ void Multiplier::report_size(ReportType type, MemoryUsage& res) res += memory_usage; } +template +const vector& Multiplier::get_multiplicands( + const vector >& others_ct, const FHE_PK&) +{ + return others_ct[P.get_full_player().get_player(-P.get_offset())]; +} + template class Multiplier; template class Multiplier; diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index e2e1ce660..9ab517a66 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -55,6 +55,9 @@ class Multiplier size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); size_t report_volatile() { return volatile_capacity; } + + const vector& get_multiplicands( + const vector>& others_ct, const FHE_PK&); }; #endif /* FHEOFFLINE_MULTIPLIER_H_ */ diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp index ed5fb303e..0fb7a14d6 100644 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -24,7 +24,7 @@ PairwiseGenerator::PairwiseGenerator(int thread_num, thread_num, machine.output, machine.get_prep_dir(P)), EC(P, machine.other_pks, machine.setup().FieldD, timers, machine, *this), MC(machine.setup().alphai), - n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)), + n_ciphertexts(EC.proof.U), C(n_ciphertexts, machine.setup().params), volatile_memory(0), machine(machine) { @@ -175,7 +175,7 @@ size_t PairwiseGenerator::report_size(ReportType type) template size_t PairwiseGenerator::report_sent() { - return P.sent; + return P.total_comm().sent; } template diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index 6ac3a82d4..dd3f8968d 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -6,6 +6,7 @@ #include "FHEOffline/PairwiseMachine.h" #include "Tools/benchmarking.h" #include "Protocols/fake-stuff.h" +#include "Tools/Bundle.h" #include "Protocols/fake-stuff.hpp" @@ -16,19 +17,17 @@ PairwiseMachine::PairwiseMachine(Player& P) : { } -PairwiseMachine::PairwiseMachine(int argc, const char** argv) : - MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")), - other_pks(N.num_players(), {setup_p.params, 0}), - pk(other_pks[N.my_num()]), sk(pk) +RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) : + MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise")) { init(); } -void PairwiseMachine::init() +void RealPairwiseMachine::init() { if (use_gf2n) { - field_size = 40; + field_size = gf2n_short::DEFAULT_LENGTH; gf2n_short::init_field(field_size); setup_keys(); } @@ -62,11 +61,11 @@ PairwiseSetup& PairwiseMachine::setup() } template -void PairwiseMachine::setup_keys() +void RealPairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); - s.init(P, drown_sec, field_size, extra_slack); + s.init(P, sec, field_size, extra_slack); if (output) write_mac_key(get_prep_dir(P), P.my_num(), P.num_players(), s.alphai); for (auto& x : other_pks) @@ -83,10 +82,11 @@ void PairwiseMachine::setup_keys() if (i != N.my_num()) other_pks[i].unpack(os[i]); set_mac_key(s.alphai); + Share::MAC_Check::setup(P); } template -void PairwiseMachine::set_mac_key(T alphai) +void RealPairwiseMachine::set_mac_key(T alphai) { typedef typename T::FD FD; auto& N = P; @@ -141,5 +141,5 @@ void PairwiseMachine::check(Player& P) const bundle.compare(P); } -template void PairwiseMachine::setup_keys(); -template void PairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); diff --git a/FHEOffline/PairwiseMachine.h b/FHEOffline/PairwiseMachine.h index c2283443e..a8a0c649e 100644 --- a/FHEOffline/PairwiseMachine.h +++ b/FHEOffline/PairwiseMachine.h @@ -10,7 +10,7 @@ #include "FHEOffline/SimpleMachine.h" #include "FHEOffline/PairwiseSetup.h" -class PairwiseMachine : public MachineBase +class PairwiseMachine : public virtual MachineBase { public: PairwiseSetup setup_p; @@ -23,15 +23,6 @@ class PairwiseMachine : public MachineBase vector enc_alphas; PairwiseMachine(Player& P); - PairwiseMachine(int argc, const char** argv); - - void init(); - - template - void setup_keys(); - - template - void set_mac_key(T alphai); template PairwiseSetup& setup(); @@ -42,4 +33,18 @@ class PairwiseMachine : public MachineBase void check(Player& P) const; }; +class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine +{ +public: + RealPairwiseMachine(int argc, const char** argv); + + void init(); + + template + void setup_keys(); + + template + void set_mac_key(T alphai); +}; + #endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */ diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index bba83b5fd..bc890ed21 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -9,6 +9,7 @@ #include "Math/Setup.h" #include "FHEOffline/Proof.h" #include "FHEOffline/PairwiseMachine.h" +#include "FHEOffline/TemiSetup.h" #include "Tools/Commit.h" #include "Tools/Bundle.h" #include "Processor/OnlineOptions.h" @@ -53,7 +54,7 @@ void PairwiseSetup::init(const Player& P, int sec, int plaintext_length, template void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); alpha = FieldD; machine.sk = FHE_SK(params, FieldD.get_prime()); for (auto& pk : machine.other_pks) @@ -62,16 +63,20 @@ void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int pla template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec) + int plaintext_length, int sec, FHE_Params& params) { + assert(sec >= 0); machine.sec = sec; - sec = max(sec, 40); - machine.drown_sec = sec; + params.set_min_sec(sec); string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(params.secp()) + "-" + + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); + string reason; + try { ifstream file(filename); @@ -79,13 +84,30 @@ void secure_init(T& setup, Player& P, U& machine, os.input(file); os.get(machine.extra_slack); setup.unpack(os); + } + catch (exception& e) + { + reason = e.what(); + } + + try + { setup.check(P, machine); } - catch (...) + catch (exception& e) { - cout << "Finding parameters for security " << sec << " and field size ~2^" - << plaintext_length << endl; - setup.params = setup.params.n_mults(); + reason = e.what(); + } + + if (not reason.empty()) + { + if (OnlineOptions::singleton.verbose) + cerr << "Generating parameters for security " << sec + << " and field size ~2^" << plaintext_length + << " because no suitable material " + "from a previous run was found (" << reason << ")" + << endl; + setup = {}; setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; @@ -94,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine, ofstream file(filename); os.output(file); } + + if (OnlineOptions::singleton.verbose) + { + cerr << "Ciphertext length: " << params.p0().numBits(); + for (size_t i = 1; i < params.FFTD().size(); i++) + cerr << "+" << params.FFTD()[i].get_prime().numBits(); + cerr << endl; + } } template @@ -208,5 +238,8 @@ void PairwiseSetup::set_alphai(T alphai) template class PairwiseSetup; template class PairwiseSetup; -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); + +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h index 8e16eaf34..f6482edec 100644 --- a/FHEOffline/PairwiseSetup.h +++ b/FHEOffline/PairwiseSetup.h @@ -15,7 +15,7 @@ class MachineBase; template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec); + int plaintext_length, int sec, FHE_Params& params); template class PairwiseSetup diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index 5714b7224..c3ab59ebf 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -167,6 +167,8 @@ string open_prep_file(ofstream& outf, string data_type, int my_num, int thread_n throw runtime_error("cannot create directory " + dir); string file = prep_filename(data_type, my_num, thread_num, initial, dir); outf.open(file.c_str(),ios::out | ios::binary | (clear ? ios::trunc : ios::app)); + if (clear) + file_signature>().output(outf); if (outf.fail()) { throw file_error(file); } return file; } @@ -516,6 +518,7 @@ InputProducer::InputProducer(const Player& P, int thread_num, if (thread_num) file << "-" << thread_num; outf[j].open(file.str().c_str(), ios::out | ios::binary); + file_signature>().output(outf[j]); if (outf[j].fail()) { throw file_error(file.str()); @@ -574,7 +577,7 @@ void InputProducer::run(const Player& P, const FHE_PK& pk, for (int j = min; j < max; j++) { AddableVector C; - vector> m(EC.machine->sec, FieldD); + vector> m(personal_EC.proof.U, FieldD); if (j == P.my_num()) { for (auto& x : m) diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index 5e690b67c..2eec0435a 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -22,6 +22,8 @@ enum SlackType class Proof { + protected: + unsigned int sec; bool diagonal; @@ -78,6 +80,7 @@ class Proof diagonal(diagonal), B_plain_length(0), B_rand_length(0), pk(&pk), n_proofs(n_proofs) { sec=sc; + assert(sec > 0); tau=Tau; rho=Rho; phim=(pk.get_params()).phi_m(); @@ -152,14 +155,18 @@ class Proof class NonInteractiveProof : public Proof { + // sec = 0 used for protocols without proofs + static int comp_sec(int sec) { return sec > 0 ? max(COMP_SEC, sec) : 0; } + public: bigint static slack(int sec, int phim) - { return bigint(phim * sec * sec) << (sec / 2 + 8); } + { sec = comp_sec(sec); return bigint(phim * sec * sec) << (sec / 2 + 8); } NonInteractiveProof(int sec, const FHE_PK& pk, int extra_slack, bool diagonal = false) : - Proof(sec, pk, 1, diagonal) + Proof(comp_sec(sec), pk, 1, diagonal) { + sec = this->sec; bigint B; B=128*sec*sec; B <<= extra_slack; diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index d92f30806..7127b8c77 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -128,6 +128,7 @@ size_t Prover::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl bool ok=false; int cnt=0; + (void) cnt; while (!ok) { cnt++; Stage_1(P,ciphertexts,c,pk); diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index dab3e507a..63d559d53 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -10,6 +10,8 @@ #include "Tools/Subroutines.h" +#include "Protocols/mac_key.hpp" + // The number of sacrifices to amortize at one time #define amortize 512 @@ -19,12 +21,7 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, int output_thread, TripleSacriFactory< Share >& factory, bool write_output, bool clear, string dir) { - if (T::length() < 40) - { - cerr << "Field too small for reasonable security" << endl; - cerr << "Use a larger field or remove this warning from " << __FILE__ << endl; - exit(1); - } + check_field_size(); ofstream outf; if (write_output) diff --git a/FHEOffline/SimpleDistDecrypt.cpp b/FHEOffline/SimpleDistDecrypt.cpp index 3774cd3c1..c8b923123 100644 --- a/FHEOffline/SimpleDistDecrypt.cpp +++ b/FHEOffline/SimpleDistDecrypt.cpp @@ -18,7 +18,12 @@ void SimpleDistDecrypt::reshare(Plaintext& EC) { (void)EC; + m = reshare(cm); +} +template +Plaintext_ SimpleDistDecrypt::reshare(const Ciphertext& cm) +{ PRNG G; G.ReSeed(); this->f.randomize(G, Full); @@ -27,10 +32,13 @@ void SimpleDistDecrypt::reshare(Plaintextrun(cm); // Step 4 + Plaintext_ m(this->f.get_field()); if (this->P.my_num()==0) { sub(m,this->mf,this->f); } else { m=this->f; m.negate(); } + + return m; } diff --git a/FHEOffline/SimpleDistDecrypt.h b/FHEOffline/SimpleDistDecrypt.h index 9589f15a1..c929a7990 100644 --- a/FHEOffline/SimpleDistDecrypt.h +++ b/FHEOffline/SimpleDistDecrypt.h @@ -20,6 +20,7 @@ class SimpleDistDecrypt : public DistDecrypt void reshare(Plaintext& m, const Ciphertext& cm, EncCommitBase& EC); + Plaintext_ reshare(const Ciphertext& cm); }; #endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */ diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index 912920679..c161f1d7b 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -26,9 +26,10 @@ SimpleEncCommit::SimpleEncCommit(const PlayerBase& P, const FHE_PK& pk int thread_num, bool diagonal) : NonInteractiveProofSimpleEncCommit(P, pk, FTD, timers, machine, diagonal), - SimpleEncCommitFactory(pk, FTD, machine, diagonal) + SimpleEncCommitFactory(pk) { (void)thread_num; + this->init(this->proof, FTD); } template @@ -48,11 +49,15 @@ NonInteractiveProofSimpleEncCommit::NonInteractiveProofSimpleEncCommit( } template -SimpleEncCommitFactory::SimpleEncCommitFactory(const FHE_PK& pk, - const FD& FTD, const MachineBase& machine, bool diagonal) : +SimpleEncCommitFactory::SimpleEncCommitFactory(const FHE_PK& pk) : cnt(-1), n_calls(0), pk(pk) { - int sec = Proof::n_ciphertext_per_proof(machine.sec, pk, diagonal); +} + +template +void SimpleEncCommitFactory::init(const Proof& proof, const FD& FTD) +{ + int sec = proof.U; c.resize(sec, pk.get_params()); m.resize(sec, FTD); for (int i = 0; i < sec; i++) @@ -224,7 +229,7 @@ template SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD, map& timers, const MachineBase& machine, int thread_num, bool diagonal) : - SimpleEncCommitFactory(pk, FTD, machine, diagonal), SimpleEncCommitBase_( + SimpleEncCommitFactory(pk), SimpleEncCommitBase_( machine), proof(machine.sec, pk, P.num_players(), diagonal), pk(pk), FTD( FTD), P(P), thread_num(thread_num), #ifdef LESS_ALLOC_MORE_MEM @@ -233,6 +238,7 @@ SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, #endif timers(timers) { + this->init(proof, FTD); } template diff --git a/FHEOffline/SimpleEncCommit.h b/FHEOffline/SimpleEncCommit.h index 9fccd9a75..f3034facf 100644 --- a/FHEOffline/SimpleEncCommit.h +++ b/FHEOffline/SimpleEncCommit.h @@ -92,9 +92,9 @@ class SimpleEncCommitFactory virtual void create_more() = 0; public: - SimpleEncCommitFactory(const FHE_PK& pk, const FD& FTD, - const MachineBase& machine, bool diagonal = false); + SimpleEncCommitFactory(const FHE_PK& pk); virtual ~SimpleEncCommitFactory(); + void init(const Proof& proof, const FD& FTD); bool has_left() { return cnt >= 0; } void next(Plaintext_& mess, Ciphertext& C); virtual size_t report_size(ReportType type); diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp index b2701b2c5..be5ee2c19 100644 --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -12,7 +12,7 @@ template