From d626b390a84348f8e83041d6c57c05d56da05f0f Mon Sep 17 00:00:00 2001 From: VincBreaker Date: Mon, 1 Nov 2021 12:49:18 +0100 Subject: [PATCH 001/221] FIX documentaion of the CRASH instruction --- Compiler/instructions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 9d246b184..0cd9e6b59 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -442,7 +442,9 @@ class join_tape(base.Instruction): arg_format = ['int'] class crash(base.IOInstruction): - """ Crash runtime. """ + """ Crash runtime if the register's value is > 0. + + :param: Crash condition (regint)""" code = base.opcodes['CRASH'] arg_format = ['ci'] From 32950fe8d4ba4df6c9ba8b51afcf057c87250bd2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 4 Nov 2021 16:22:45 +1100 Subject: [PATCH 002/221] Maintenance. --- BMR/RealProgramParty.hpp | 2 - BMR/Register.h | 3 + CHANGELOG.md | 9 ++ CONFIG | 4 +- Compiler/GC/types.py | 2 + Compiler/allocator.py | 2 + Compiler/floatingpoint.py | 2 + Compiler/instructions.py | 109 +++++++++++++++++++++- Compiler/instructions_base.py | 12 ++- Compiler/ml.py | 136 +++++++++++++++++++++------ Compiler/program.py | 3 +- Compiler/types.py | 150 ++++++++++++++++++++++++------ ECDSA/fake-spdz-ecdsa-party.cpp | 4 +- ECDSA/mascot-ecdsa-party.cpp | 2 + ECDSA/ot-ecdsa-party.hpp | 2 - FHEOffline/PairwiseMachine.cpp | 1 + FHEOffline/Producer.cpp | 3 + GC/AtlasSecret.cpp | 7 -- GC/AtlasShare.h | 2 - GC/CcdPrep.h | 10 +- GC/CcdPrep.hpp | 1 + GC/CcdShare.h | 7 -- GC/FakeSecret.h | 2 + GC/MaliciousCcdShare.h | 7 -- GC/NoShare.h | 13 ++- GC/Processor.h | 2 +- GC/Processor.hpp | 20 ++-- GC/RepPrep.h | 1 - GC/RepPrep.hpp | 7 -- GC/RuntimeBranching.h | 5 + GC/Secret.h | 4 +- GC/Secret.hpp | 6 -- GC/SemiHonestRepPrep.h | 4 - GC/SemiPrep.cpp | 5 - GC/SemiPrep.h | 1 - GC/SemiSecret.h | 3 + GC/ShareSecret.h | 3 + GC/ShareThread.h | 7 +- GC/ShareThread.hpp | 33 +++---- GC/Thread.hpp | 6 +- GC/ThreadMaster.hpp | 10 ++ GC/TinierShare.h | 5 - GC/TinierSharePrep.h | 3 - GC/TinierSharePrep.hpp | 11 +-- GC/TinyPrep.hpp | 2 + GC/TinySecret.h | 22 +++++ GC/TinyShare.h | 7 -- Machines/TripleMachine.cpp | 2 +- Machines/emulate.cpp | 3 + Machines/real-bmr-party.cpp | 2 + Makefile | 29 +++--- Math/FixedVec.h | 5 + Math/Integer.h | 4 + Math/Integer.hpp | 15 +++ Math/gf2n.cpp | 8 ++ Math/gf2n.h | 2 + Math/gf2nlong.h | 23 +++++ Math/gfpvar.cpp | 8 +- Math/gfpvar.h | 2 + Networking/CryptoPlayer.cpp | 52 ++++++++--- Networking/CryptoPlayer.h | 1 - Networking/Player.h | 3 +- Networking/sockets.cpp | 4 +- Networking/ssl_sockets.h | 6 +- OT/NPartyTripleGenerator.hpp | 6 ++ Processor/DataPositions.cpp | 7 ++ Processor/Data_Files.h | 32 ++++++- Processor/Data_Files.hpp | 43 ++++++++- Processor/Input.h | 13 +++ Processor/InputTuple.h | 5 + Processor/Instruction.h | 11 +-- Processor/Instruction.hpp | 70 ++++++++------ Processor/Machine.h | 10 +- Processor/Machine.hpp | 21 ++--- Processor/Memory.h | 27 ++++-- Processor/Memory.hpp | 34 +++---- Processor/OfflineMachine.hpp | 8 +- Processor/Online-Thread.h | 2 +- Processor/Online-Thread.hpp | 8 +- Processor/OnlineMachine.hpp | 2 +- Processor/OnlineOptions.cpp | 95 ++++++++++--------- Processor/OnlineOptions.h | 2 + Processor/PrepBase.cpp | 7 -- Processor/Processor.h | 48 +++++----- Processor/Processor.hpp | 37 +++----- Processor/ProcessorBase.cpp | 15 ++- Processor/ProcessorBase.h | 7 +- Processor/ProcessorBase.hpp | 3 +- Processor/RingMachine.hpp | 2 +- Programs/Source/mnist_49.mpc | 1 + Programs/Source/mnist_full_A.mpc | 1 + Programs/Source/mnist_full_B.mpc | 1 + Programs/Source/mnist_full_C.mpc | 1 + Programs/Source/mnist_full_D.mpc | 1 + Programs/Source/tf.mpc | 2 + Protocols/Atlas.h | 4 + Protocols/AtlasPrep.h | 4 + Protocols/Beaver.h | 3 + Protocols/ChaiGearPrep.h | 3 + Protocols/CowGearPrep.h | 3 + Protocols/FakePrep.h | 10 ++ Protocols/FakeProtocol.h | 69 ++++++++++++++ Protocols/Hemi.h | 3 + Protocols/Hemi.hpp | 3 + Protocols/HemiMatrixPrep.h | 3 + Protocols/HemiPrep.h | 3 + Protocols/HemiShare.h | 1 + Protocols/HighGearKeyGen.h | 3 + Protocols/HighGearKeyGen.hpp | 4 + Protocols/LowGearKeyGen.h | 6 ++ Protocols/LowGearKeyGen.hpp | 3 + Protocols/MAC_Check.h | 18 ++++ Protocols/MAC_Check.hpp | 8 +- Protocols/MAC_Check_Base.h | 13 ++- Protocols/MalRepRingPrep.h | 7 ++ Protocols/MaliciousRepMC.h | 3 + Protocols/MaliciousRepPrep.h | 6 ++ Protocols/MaliciousShamirMC.h | 3 + Protocols/MamaPrep.h | 3 + Protocols/MascotPrep.h | 12 +++ Protocols/NoShare.h | 1 + Protocols/PostSacrifice.h | 3 + Protocols/Rep4.h | 3 + Protocols/RepRingOnlyEdabitPrep.h | 3 + Protocols/Replicated.h | 19 ++++ Protocols/ReplicatedInput.h | 6 ++ Protocols/ReplicatedMC.h | 3 + Protocols/ReplicatedPrep.h | 37 +++++++- Protocols/ReplicatedPrep.hpp | 43 ++++++--- Protocols/ReplicatedPrep2k.h | 3 + Protocols/RingOnlyPrep.h | 3 + Protocols/SPDZ.h | 3 + Protocols/Semi2k.h | 3 + Protocols/SemiInput.h | 3 + Protocols/SemiMC.h | 6 ++ Protocols/SemiPrep.h | 3 + Protocols/SemiPrep2k.h | 3 + Protocols/Shamir.h | 3 + Protocols/Shamir.hpp | 6 +- Protocols/ShamirInput.h | 7 ++ Protocols/ShamirMC.h | 6 ++ Protocols/ShuffleSacrifice.h | 3 + Protocols/SohoPrep.h | 3 + Protocols/Spdz2kPrep.h | 3 + Protocols/SpdzWise.h | 3 + Protocols/SpdzWiseInput.h | 3 + Protocols/SpdzWisePrep.h | 3 + Protocols/SpdzWiseRing.h | 3 + Protocols/SpdzWiseRingPrep.h | 4 + Protocols/SpdzWiseShare.h | 1 + Protocols/SpdzWiseShare.hpp | 6 ++ Protocols/dabit.h | 5 + Protocols/fake-stuff.h | 16 +++- Protocols/fake-stuff.hpp | 54 ++--------- README.md | 23 +++-- Scripts/run-common.sh | 2 +- Scripts/yao.sh | 14 +-- Tools/Buffer.cpp | 38 ++++++-- Tools/Buffer.h | 41 +++++++- Tools/Exceptions.cpp | 28 +++++- Tools/Exceptions.h | 23 +++-- Tools/aes-arm.h | 32 +++---- Tools/cpu_support.h | 7 ++ Utils/Check-Offline-Z2k.cpp | 3 + Utils/Check-Offline.cpp | 6 +- Utils/Fake-Offline.cpp | 53 +++-------- Utils/check-passive.cpp | 4 + Yao/YaoEvalWire.cpp | 30 +++++- Yao/YaoEvalWire.h | 1 + Yao/YaoEvaluator.cpp | 3 +- Yao/YaoGarbleWire.cpp | 36 ++++++- Yao/YaoGarbleWire.h | 2 + Yao/YaoGarbler.cpp | 9 ++ Yao/YaoPlayer.cpp | 58 ++---------- compile.py | 2 +- doc/Doxyfile | 2 +- doc/_static/custom.css | 3 + doc/add-protocol.rst | 6 +- doc/conf.py | 3 + doc/gen-instructions.py | 2 +- doc/instructions.rst | 1 - doc/low-level.rst | 148 +++++++++++++++++++++++++++++ doc/networking.rst | 2 + doc/non-linear.rst | 42 ++++++++- doc/troubleshooting.rst | 23 ++++- 185 files changed, 1818 insertions(+), 654 deletions(-) create mode 100644 doc/_static/custom.css diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 22438deb4..0c97f9bd8 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -112,8 +112,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : garble_processor.reset(program); this->processor.open_input_file(N.my_num(), 0); - T::bit_type::mac_key_type::init_field(); - GC::ShareThread share_thread(N, online_opts, *P, 0, usage); shared_proc = new SubProcessor(dummy_proc, *MC, *prep, *P); auto& inputter = shared_proc->input; diff --git a/BMR/Register.h b/BMR/Register.h index 886155d79..4d0c1b074 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -243,6 +243,9 @@ class Phase template static T get_input(int from, GC::Processor& processor, int n_bits) { return T::input(from, processor.get_input(n_bits), n_bits); } + template + static void reveal_inst(GC::Processor& processor, const vector& args) + { processor.reveal(args); } template static void convcbit(Integer& dest, const GC::Clear& source, T&) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e23aaa68..8c9be9e5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.8 (Nov 4, 2021) + +- Tested on Apple laptop with ARM chip +- Restore trusted client interface +- Directly accessible softmax function +- Signature in preprocessing files to reduce confusing errors +- Improved error messages for connection issues +- Documentation of low-level share types and protocol pairs + ## 0.2.7 (Sep 17, 2021) - Optimized matrix multiplication in Hemi diff --git a/CONFIG b/CONFIG index 5f12d2c33..ba6855ea9 100644 --- a/CONFIG +++ b/CONFIG @@ -88,7 +88,9 @@ CPPFLAGS = $(CFLAGS) LD = $(CXX) ifeq ($(OS), Darwin) +# for boost with OpenSSL 3 +CFLAGS += -Wno-error=deprecated-declarations ifeq ($(USE_NTL),1) -CFLAGS += -Wno-error=unused-parameter +CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy endif endif diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index a1384475b..6a5e39f13 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -309,6 +309,8 @@ def __invert__(self): res = type(self)() inst.notcb(self.n, res, self) return res + def __eq__(self, other): + raise CompilerError('equality not implemented') def print_reg(self, desc=''): inst.print_regb(self, desc) def print_reg_plain(self): diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 6a1472d7b..7ce9896b1 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -19,6 +19,8 @@ def __init__(self): self.by_address = {} def by_size(self, size): + if size >= 2 ** 32: + raise CompilerError('size exceeds addressing capability') return self.by_logsize[int(math.log(size, 2))][size] def push(self, address, size): diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 66de0859a..a15a62dd9 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -290,6 +290,7 @@ def BitDecRing(a, k, m): return [types.sint.conv(bit) for bit in reversed(bits)][::-1] def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): + instructions_base.set_global_vector_size(a.size) r_dprime = types.sint() r_prime = types.sint() c = types.cint() @@ -298,6 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): pow2 = two_power(k + kappa) asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) + instructions_base.reset_global_vector_size() return res def BitDecField(a, k, m, kappa, bits_to_compute=None): diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 0cd9e6b59..8069b0ffb 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1496,6 +1496,14 @@ class cond_print_plain(base.IOInstruction): code = base.opcodes['CONDPRINTPLAIN'] arg_format = ['c', 'c', 'c'] + def __init__(self, *args, **kwargs): + base.Instruction.__init__(self, *args, **kwargs) + self.size = args[1].size + args[2].set_size(self.size) + + def get_code(self): + return base.Instruction.get_code(self, self.size) + class print_int(base.IOInstruction): """ Output clear integer register. @@ -1591,7 +1599,16 @@ def has_var_args(self): return True class readsockets(base.IOInstruction): - """Read a variable number of secret shares + MACs from socket for a client id and store in registers""" + """ Read a variable number of secret shares (potentially with MAC) + from a socket for a client id and store them in registers. If the + protocol uses MACs, the client should be different for every party. + + :param: client id (regint) + :param: vector size (int) + :param: source (sint) + :param: (repeat source)... + + """ __slots__ = [] code = base.opcodes['READSOCKETS'] arg_format = tools.chain(['ci','int'], itertools.repeat('sw')) @@ -1628,6 +1645,25 @@ class writesocketc(base.IOInstruction): def has_var_args(self): return True +class writesockets(base.IOInstruction): + """ Write a variable number of secret shares (potentially with MAC) + from registers into a socket for a specified client id. If the + protocol uses MACs, the client should be different for every party. + + :param: client id (regint) + :param: message type (must be 0) + :param: vector size (int) + :param: source (sint) + :param: (repeat source)... + + """ + __slots__ = [] + code = base.opcodes['WRITESOCKETS'] + arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s')) + + def has_var_args(self): + return True + class writesocketshare(base.IOInstruction): """ Write a variable number of shares (without MACs) from secret registers into socket for a specified client id. @@ -2384,5 +2420,76 @@ def expand(self): subs(a, self.args[1], self.args[2]) comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) +# placeholder for documentation +class cisc: + """ Meta instruction for emulation. This instruction is only generated + when using ``-K`` with ``compile.py``. The header looks as follows: + + :param: number of arguments after name plus one + :param: name (16 bytes, zero-padded) + + Currently, the following names are supported: + + LTZ + Less than zero. + + :param: number of arguments in this unit (must be 6) + :param: vector size + :param: result (sint) + :param: input (sint) + :param: bit length + :param: (ignored) + :param: (repeat)... + + Trunc + Truncation. + + :param: number of arguments in this unit (must be 8) + :param: vector size + :param: result (sint) + :param: input (sint) + :param: bit length + :param: number of bits to truncate + :param: (ignored) + :param: 0 for unsigned or 1 for signed + :param: (repeat)... + + FPDiv + Fixed-point division. Division by zero results in zero without error. + + :param: number of arguments in this unit (must be at least 7) + :param: vector size + :param: result (sint) + :param: dividend (sint) + :param: divisor (sint) + :param: (ignored) + :param: fixed-point precision + :param: (repeat)... + + exp2_fx + Fixed-point power of two. + + :param: number of arguments in this unit (must be at least 6) + :param: vector size + :param: result (sint) + :param: exponent (sint) + :param: (ignored) + :param: fixed-point precision + :param: (repeat)... + + log2_fx + Fixed-point logarithm with base 2. + + :param: number of arguments in this unit (must be at least 6) + :param: vector size + :param: result (sint) + :param: input (sint) + :param: (ignored) + :param: fixed-point precision + :param: (repeat)... + + """ + code = base.opcodes['CISC'] + # hack for circular dependency from Compiler import comparison diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 66b55fd9d..38fd97d29 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -481,7 +481,16 @@ def new_instructions(self, size, regs): def expand_merged(self, skip): if function.__name__ in skip: - return [self], 0 + good = True + for call in self.calls: + if not good: + break + for arg in call[0]: + if isinstance(arg, program.curr_tape.Register) and \ + not issubclass(type(self.calls[0][0][0]), type(arg)): + good = False + if good: + return [self], 0 tape = program.curr_tape block = tape.BasicBlock(tape, None, None) tape.active_basicblock = block @@ -520,6 +529,7 @@ def get_bytes(self): String.check(name) res += String.encode(name) for call in self.calls: + call[1].pop('nearest', None) assert not call[1] res += int_to_bytes(len(call[0]) + 2) res += int_to_bytes(call[0][0].size) diff --git a/Compiler/ml.py b/Compiler/ml.py index 42389ae83..a8be2d533 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -155,6 +155,26 @@ def op(a, b): return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1]) return tree_reduce(op, enumerate(x))[0] +def softmax(x): + """ Softmax. + + :param x: vector or list of sfix + :returns: sfix vector + """ + return softmax_from_exp(exp_for_softmax(x)[0]) + +def exp_for_softmax(x): + m = util.max(x) + mv = m.expand_to_vector(len(x)) + try: + x = x.get_vector() + except AttributeError: + x = sfix(x) + return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m + +def softmax_from_exp(x): + return x / sum(x) + report_progress = False def progress(x): @@ -464,10 +484,7 @@ def _(i): self.losses[i] = -sfix.dot_product( self.Y[batch[i]].get_vector(), log_e(div)) else: - m = util.max(self.X[i]) - mv = m.expand_to_vector(d_out) - x = self.X[i].get_vector() - e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0) + e, m = exp_for_softmax(self.X[i]) self.exp[i].assign_vector(e) if self.compute_loss: true_X = sfix.dot_product(self.Y[batch[i]], self.X[i]) @@ -532,11 +549,8 @@ def _(j): return @for_range_opt_multithread(self.n_threads, len(batch)) def _(i): - for j in range(d_out): - dividend = self.exp[i][j] - divisor = sum(self.exp[i]) - div = (divisor > 0.1).if_else(dividend / divisor, 0) - self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div) + div = softmax_from_exp(self.exp[i]) + self.nabla_X[i][:] = -self.Y[batch[i]][:] + div self.maybe_debug_backward(batch) def maybe_debug_backward(self, batch): @@ -588,7 +602,7 @@ class DenseBase(Layer): nablas = lambda self: (self.nabla_W, self.nabla_b) def output_weights(self): - print_ln('%s', self.W.reveal_nested()) + self.W.print_reveal_nested() print_ln('%s', self.b.reveal_nested()) def backward_params(self, f_schur_Y, batch): @@ -1316,7 +1330,7 @@ def input_from(self, player, raw=False): self.bias.input_from(player, raw=raw) def output_weights(self): - print_ln('%s', self.weights.reveal_nested()) + self.weights.print_reveal_nested() print_ln('%s', self.bias.reveal_nested()) def dot_product(self, iv, wv, out_y, out_x, out_c): @@ -1942,7 +1956,12 @@ def _(_): for label, X in enumerate(self.X_by_label): indices = regint.Array(n * n_per_epoch) indices_by_label.append(indices) - indices.assign(regint.inc(len(indices), 0, 1, 1, len(X))) + indices.assign(regint.inc(len(X))) + missing = len(indices) - len(X) + if missing: + indices.assign_vector( + regint.get_random(int(math.log2(len(X))), size=missing), + base=len(X)) if self.always_shuffle or n_per_epoch > 1: indices.shuffle() loss_sum = MemValue(sfix(0)) @@ -2050,6 +2069,8 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, self.time_layers = 'time_layers' in program.args self.revealing_correctness = not 'no_acc' in program.args self.layers[-1].compute_loss = not 'no_loss' in program.args + if 'full_cisc' in program.args: + program.options.keep_cisc = 'FPDiv,exp2_fx,log2_fx' model_input = 'model_input' in program.args acc_first = model_input and not 'train_first' in program.args if model_input: @@ -2058,12 +2079,14 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, else: self.reset() if 'one_iter' in program.args: + print_float_prec(16) self.output_weights() print_ln('loss') - print_ln('%s', self.eval( - self.layers[0].X.get_part(0, batch_size)).reveal_nested()) + self.eval( + self.layers[0].X.get_part(0, batch_size), + batch_size=batch_size).print_reveal_nested() for layer in self.layers: - print_ln('%s', layer.X.get_part(0, batch_size).reveal_nested()) + layer.X.get_part(0, batch_size).print_reveal_nested() print_ln('%s', self.layers[-1].Y.get_part(0, batch_size).reveal_nested()) batch = Array.create_from(regint.inc(batch_size)) self.forward(batch=batch, training=True) @@ -2083,9 +2106,10 @@ def _(i): return N = self.layers[0].X.sizes[0] n_trained = (N + batch_size - 1) // batch_size * batch_size - print_ln('train_acc: %s (%s/%s)', - cfix(self.n_correct, k=63, f=31) / n_trained, - self.n_correct, n_trained) + if not acc_first: + print_ln('train_acc: %s (%s/%s)', + cfix(self.n_correct, k=63, f=31) / n_trained, + self.n_correct, n_trained) if test_X and test_Y: n_test = len(test_Y) n_correct, loss = self.reveal_correctness(test_X, test_Y, @@ -2500,16 +2524,30 @@ def predict(self, x, batch_size=None): batch_size = min(batch_size, self.batch_size) return self.opt.eval(x, batch_size=batch_size) -def solve_linear(A, b, n_iterations, progress=False): - """ Iterative linear solution approximation. """ +def solve_linear(A, b, n_iterations, progress=False, n_threads=None, + stop=False, already_symmetric=False, precond=False): + """ Iterative linear solution approximation for :math:`Ax=b`. + + :param progress: print some information on the progress (implies revealing) + :param n_threads: number of threads to use + :param stop: whether to stop when converged (implies revealing) + + """ assert len(b) == A.sizes[0] x = sfix.Array(A.sizes[1]) x.assign_vector(sfix.get_random(-1, 1, size=len(x))) - AtA = sfix.Matrix(len(x), len(x)) - AtA[:] = A.direct_trans_mul(A) + if already_symmetric: + AtA = A + r = Array.create_from(b - AtA * x) + else: + AtA = sfix.Matrix(len(x), len(x)) + A.trans_mul_to(A, AtA, n_threads=n_threads) + r = Array.create_from(A.transpose() * b - AtA * x) + if precond: + return solve_linear_diag_precond(AtA, b, x, r, n_iterations, + progress, stop) v = sfix.Array(A.sizes[1]) v.assign_all(0) - r = Array.create_from(A.transpose() * b - AtA * x) Av = sfix.Array(len(x)) @for_range(n_iterations) def _(i): @@ -2523,10 +2561,43 @@ def _(i): if progress: print_ln('%s alpha=%s vr=%s v_norm=%s', i, alpha.reveal(), vr.reveal(), v_norm.reveal()) + if stop: + return (alpha > 0).reveal() return x -def mr(A, n_iterations): - """ Iterative matrix inverse approximation. """ +def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False, + stop=False): + m = 1 / A.diag() + mr = Array.create_from(m * r[:]) + d = Array.create_from(mr) + @for_range(n_iterations) + def _(i): + Ad = A * d + d_norm = sfix.dot_product(d, Ad) + alpha = (d_norm == 0).if_else(0, sfix.dot_product(r, mr) / d_norm) + x[:] = x[:] + alpha * d[:] + r_norm = sfix.dot_product(r, mr) + r[:] = r[:] - alpha * Ad + tmp = m * r[:] + beta = (r_norm == 0).if_else(0, sfix.dot_product(r, tmp) / r_norm) + mr[:] = tmp + d[:] = tmp + beta * d + if progress: + print_ln('%s alpha=%s beta=%s r_norm=%s d_norm=%s', i, + alpha.reveal(), beta.reveal(), r_norm.reveal(), + d_norm.reveal()) + if stop: + return (alpha > 0).reveal() + return x + +def mr(A, n_iterations, stop=False): + """ Iterative matrix inverse approximation. + + :param A: matrix to invert + :param n_iterations: maximum number of iterations + :param stop: whether to stop when converged (implies revealing) + + """ assert len(A.sizes) == 2 assert A.sizes[0] == A.sizes[1] M = A.same_shape() @@ -2536,5 +2607,18 @@ def _(i): e = sfix.Array(n) e.assign_all(0) e[i] = 1 - M[i] = solve_linear(A, e, n_iterations) + M[i] = solve_linear(A, e, n_iterations, stop=stop) return M.transpose() + +def var(x): + """ Variance. """ + mean = MemValue(type(x[0])(0)) + @for_range_opt(len(x)) + def _(i): + mean.iadd(x[i]) + mean /= len(x) + res = MemValue(type(x[0])(0)) + @for_range_opt(len(x)) + def _(i): + res.iadd((x[i] - mean.read()) ** 2) + return res.read() diff --git a/Compiler/program.py b/Compiler/program.py index ebc1601c8..19ce52480 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -690,8 +690,9 @@ def add_usage(self, req_node): def expand_cisc(self): new_instructions = [] - if self.parent.program.options.keep_cisc: + if self.parent.program.options.keep_cisc != None: skip = ['LTZ', 'Trunc'] + skip += self.parent.program.options.keep_cisc.split(',') else: skip = [] for inst in self.instructions: diff --git a/Compiler/types.py b/Compiler/types.py index 917eba752..151c805ea 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1161,7 +1161,7 @@ def print_if(self, string): cond_print_str(self, string) def output_if(self, cond): - cond_print_plain(self.conv(cond), self, cint(0)) + cond_print_plain(self.conv(cond), self, cint(0, size=self.size)) class cgf2n(_clear, _gf2n): @@ -1922,14 +1922,7 @@ def load_other(self, val): r = self.get_dabit() movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit())) elif isinstance(val, sbitvec): - assert(sum(x.n for x in val.v) == self.size) - for val_part, base in zip(val, range(0, self.size, 64)): - left = min(64, self.size - base) - r = self.get_dabit(size=left) - v = regint(size=left) - bitdecint_class(regint((r[1] ^ val_part).reveal()), *v) - part = r[0].bit_xor(v) - vmovs(left, self.get_vector(base, left), part) + movs(self, sint.bit_compose(val)) else: self.load_clear(self.clear_type(val)) @@ -1939,6 +1932,8 @@ def bit_compose(cls, bits): :param bits: iterable of any type convertible to sint """ from Compiler.GC.types import sbits, sbitintvec + if isinstance(bits, sbits): + bits = bits.bit_decompose() bits = list(bits) if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits): if program.use_edabit(): @@ -2112,6 +2107,11 @@ class sint(_secret, _int): thereof or sbits/sbitvec/sfix) :param size: vector size (int), defaults to 1 or size of list + When converting :py:class:`~Compiler.GC.types.sbits`, the result is a + vector of bits, and when converting + :py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values + with bit length equal the length of the input. + """ __slots__ = [] instruction_type = 'modp' @@ -2278,6 +2278,17 @@ def read_from_socket(cls, client_id, n=1): else: return res + @vectorized_classmethod + def write_to_socket(cls, client_id, values, + message_type=ClientMessageType.NoType): + """ Send a list of shares and MAC shares to a client socket. + + :param client_id: regint + :param values: list of sint + + """ + writesockets(client_id, message_type, values[0].size, *values) + @vectorize def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType): """ Send only share to socket """ @@ -2339,6 +2350,7 @@ def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None): @vectorize_init def __init__(self, val=None, size=None): + from .GC.types import sbitvec if isinstance(val, personal): size = val._v.size super(sint, self).__init__('s', size=size) @@ -2346,6 +2358,8 @@ def __init__(self, val=None, size=None): elif isinstance(val, _fix): super(sint, self).__init__('s', size=val.v.size) self.load_other(val.v.round(val.k, val.f)) + elif isinstance(val, sbitvec): + super(sint, self).__init__('s', val=val, size=val[0].n) else: super(sint, self).__init__('s', val=val, size=size) @@ -3747,13 +3761,14 @@ def __rtruediv__(self, other): other = parse_type(other, self.k, self.f) return other / self + @vectorize def print_plain(self): """ Clear fixed-point output. """ print_float_plain(cint.conv(self.v), cint(-self.f), \ cint(0), cint(0), cint(0)) def output_if(self, cond): - cond_print_plain(cint.conv(cond), self.v, cint(-self.f)) + cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size)) @vectorize def binary_output(self, player=None): @@ -4028,7 +4043,8 @@ def set_precision_from_args(cls, program, adapt_ring=False): print('Nearest rounding instead of proabilistic ' 'for fixed-point computation') cls.round_nearest = True - if adapt_ring and program.options.ring: + if adapt_ring and program.options.ring \ + and 'fix_ring' not in program.args: need = 2 ** int(math.ceil(math.log(2 * cls.k, 2))) if need != int(program.options.ring): print('Changing computation modulus to 2^%d' % need) @@ -4094,6 +4110,8 @@ def __init__(self, _v=None, k=None, f=None, size=None): elif isinstance(_v, (MemValue, MemFix)): #this is a memvalue object self.v = type(self)(_v.read()).v + elif isinstance(_v, (list, tuple)): + self.v = self.int_type(list(self.conv(x).v for x in _v)) else: raise CompilerError('cannot convert %s to sfix' % _v) if not isinstance(self.v, self.int_type): @@ -4250,7 +4268,7 @@ def get_raw_input_from(cls, player): return cls._new(cls.int_type.get_raw_input_from(player)) @vectorized_classmethod - def get_random(cls, lower, upper): + def get_random(cls, lower, upper, symmetric=True): """ Uniform secret random number around centre of bounds. Actual range can be smaller but never larger. @@ -4261,8 +4279,19 @@ def get_random(cls, lower, upper): log_range = int(math.log(upper - lower, 2)) n_bits = log_range + cls.f average = lower + 0.5 * (upper - lower) - lower = average - 0.5 * 2 ** log_range - return cls._new(cls.int_type.get_random_int(n_bits)) + lower + real_range = (2 ** (n_bits) - 1) / 2 ** cls.f + lower = average - 0.5 * real_range + real_lower = round(lower * 2 ** cls.f) / 2 ** cls.f + r = cls._new(cls.int_type.get_random_int(n_bits)) + lower + if symmetric: + lowest = math.floor(lower * 2 ** cls.f) / 2 ** cls.f + print('randomness range [%f,%f], fringes half the probability' % \ + (lowest, lowest + 2 ** log_range)) + return cls.int_type.get_random_bit().if_else(r, -r + 2 * average) + else: + print('randomness range [%f,%f], %d bits' % \ + (real_lower, real_lower + real_range, n_bits)) + return r @classmethod def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None): @@ -4985,6 +5014,7 @@ class cfloat(object): """ Helper class for printing revealed sfloats. """ __slots__ = ['v', 'p', 'z', 's', 'nan'] + @vectorize_init def __init__(self, v, p=None, z=None, s=None, nan=0): """ Parameters as with :py:class:`sfloat` but public. """ if s is None: @@ -4993,6 +5023,11 @@ def __init__(self, v, p=None, z=None, s=None, nan=0): parts = [cint.conv(x) for x in (v, p, z, s, nan)] self.v, self.p, self.z, self.s, self.nan = parts + @property + def size(self): + return self.v.size + + @vectorize def print_float_plain(self): """ Output. """ print_float_plain(self.v, self.p, self.z, self.s, self.nan) @@ -5389,15 +5424,7 @@ def __neg__(self): def shuffle(self): """ Insecure shuffle in place. """ - if self.value_type == regint: - self.assign(self.get_vector().shuffle()) - else: - @library.for_range(len(self)) - def _(i): - j = regint.get_random(64) % (len(self) - i) - tmp = self[i] - self[i] = self[i + j] - self[i + j] = tmp + self.assign_vector(self.get(regint.inc(len(self)).shuffle())) def reveal(self): """ Reveal the whole array. @@ -5411,6 +5438,13 @@ def reveal_list(self): reveal_nested = reveal_list + def print_reveal_nested(self, end='\n'): + """ Reveal and print as list. + + :param end: string to print after (default: line break) + """ + library.print_str('%s' + end, self.get_vector().reveal()) + def reveal_to_binary_output(self, player=None): """ Reveal to binary output if supported by type. @@ -5449,6 +5483,8 @@ def __str__(self): class SubMultiArray(_vectorizable): """ Multidimensional array functionality. Don't construct this directly, use :py:class:`MultiArray` instead. """ + check_indices = True + def __init__(self, sizes, value_type, address, index, debug=None): self.sizes = tuple(sizes) self.value_type = _get_type(value_type) @@ -5458,7 +5494,6 @@ def __init__(self, sizes, value_type, address, index, debug=None): self.address = None self.sub_cache = {} self.debug = debug - self.check_indices = True if debug: library.print_ln_if(self.address + reduce(operator.mul, self.sizes) * self.value_type.n_elements() > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug) @@ -5467,8 +5502,6 @@ def __getitem__(self, index): :param index: public (regint/cint/int) :return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise""" - if util.is_constant(index) and index >= self.sizes[0]: - raise StopIteration if isinstance(index, slice) and index == slice(None): return self.get_vector() key = program.curr_block, str(index) @@ -5490,7 +5523,9 @@ def __getitem__(self, index): self.sub_cache[key] = \ SubMultiArray(self.sizes[1:], self.value_type, \ self.address, index, debug=self.debug) - return self.sub_cache[key] + res = self.sub_cache[key] + res.check_indices = self.check_indices + return res def __setitem__(self, index, other): """ Part assignment. @@ -5505,6 +5540,9 @@ def __len__(self): """ Size of top dimension. """ return self.sizes[0] + def __iter__(self): + return (self[i] for i in range(len(self))) + def assign_all(self, value): """ Assign the same value to all entries. @@ -5877,6 +5915,38 @@ def direct_trans_mul(self, other, reduce=True, indices=None): self.address, other.address, None, 1, other.sizes[1], reduce=reduce, indices=indices) + def trans_mul_to(self, other, res, n_threads=None): + """ + Matrix multiplication with the transpose of :py:obj:`self` + in the virtual machine. + + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param res: matrix of matching dimension to store result + :param n_threads: number of threads (default: single thread) + """ + @library.for_range_multithread(n_threads, 1, self.sizes[1]) + def _(i): + indices = [regint(i), regint.inc(self.sizes[0])] + indices += [regint.inc(i) for i in other.sizes] + res[i] = self.direct_trans_mul(other, indices=indices) + + def mul_trans_to(self, other, res, n_threads=None): + """ + Matrix multiplication with the transpose of :py:obj:`other` + in the virtual machine. + + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param res: matrix of matching dimension to store result + :param n_threads: number of threads (default: single thread) + """ + @library.for_range_multithread(n_threads, 1, self.sizes[0]) + def _(i): + indices = [regint(i), regint.inc(self.sizes[1])] + indices += [regint.inc(i) for i in reversed(other.sizes)] + res[i] = self.direct_mul_trans(other, indices=indices) + def direct_mul_to_matrix(self, other): """ Matrix multiplication in the virtual machine. @@ -5992,6 +6062,13 @@ def trace(self): assert self.sizes[0] == self.sizes[1] return sum(self[i][i] for i in range(self.sizes[0])) + def diag(self): + """ Matrix diagonal. """ + assert len(self.sizes) == 2 + assert self.sizes[0] == self.sizes[1] + n = self.sizes[0] + return self.array.get(regint.inc(n, 0, n + 1)) + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6007,6 +6084,21 @@ def f(sizes): return [f(sizes[1:]) for i in range(sizes[0])] return f(self.sizes) + def print_reveal_nested(self, end='\n'): + """ Reveal and print as nested list. + + :param end: string to print after (default: line break) + """ + if self.total_size() < program.options.budget: + library.print_str('%s' + end, self.reveal_nested()) + else: + library.print_str('[') + @library.for_range(len(self) - 1) + def _(i): + self[i].print_reveal_nested(end=', ') + self[len(self) - 1].print_reveal_nested(end='') + library.print_str(']' + end) + def reveal_to_binary_output(self, player=None): """ Reveal to binary output if supported by type. @@ -6042,6 +6134,10 @@ class MultiArray(SubMultiArray): a[2][:] = a[0][:] * a[1][:] """ + @staticmethod + def disable_index_checks(): + SubMultiArray.check_indices = False + def __init__(self, sizes, value_type, debug=None, address=None, alloc=True): if isinstance(address, Array): self.array = address diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 6cc2fbcc8..f0e3257c6 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -3,6 +3,8 @@ * */ +#define NO_MIXED_CIRCUITS + #include "Networking/Server.h" #include "Networking/CryptoPlayer.h" #include "Math/gfp.h" @@ -49,8 +51,6 @@ int main(int argc, const char** argv) ArithmeticProcessor _({}, 0); BaseMachine machine; machine.ot_setups.push_back({P, false}); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, {}, usage); SubProcessor proc(_, MCp, prep, P); pShare sk, __; diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index dc2edab31..87573593b 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -3,6 +3,8 @@ * */ +#define NO_MIXED_CIRCUITS + #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 1c4a442ae..58f35d4b0 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -106,8 +106,6 @@ void run(int argc, const char** argv) typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); typename pShare::TriplePrep sk_prep(0, usage); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, {}, usage); SubProcessor sk_proc(_, MCp, sk_prep, P); pShare sk, __; // synchronize diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index 6ac3a82d4..e41fe1837 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -6,6 +6,7 @@ #include "FHEOffline/PairwiseMachine.h" #include "Tools/benchmarking.h" #include "Protocols/fake-stuff.h" +#include "Tools/Bundle.h" #include "Protocols/fake-stuff.hpp" diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index 5714b7224..9b143dd66 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -167,6 +167,8 @@ string open_prep_file(ofstream& outf, string data_type, int my_num, int thread_n throw runtime_error("cannot create directory " + dir); string file = prep_filename(data_type, my_num, thread_num, initial, dir); outf.open(file.c_str(),ios::out | ios::binary | (clear ? ios::trunc : ios::app)); + if (clear) + file_signature>().output(outf); if (outf.fail()) { throw file_error(file); } return file; } @@ -516,6 +518,7 @@ InputProducer::InputProducer(const Player& P, int thread_num, if (thread_num) file << "-" << thread_num; outf[j].open(file.str().c_str(), ios::out | ios::binary); + file_signature>().output(outf[j]); if (outf[j].fail()) { throw file_error(file.str()); diff --git a/GC/AtlasSecret.cpp b/GC/AtlasSecret.cpp index 2c396bfbc..92f290f31 100644 --- a/GC/AtlasSecret.cpp +++ b/GC/AtlasSecret.cpp @@ -24,11 +24,4 @@ AtlasShare::AtlasShare(const AtlasSecret& other) : { } -void AtlasShare::random() -{ - AtlasSecret tmp; - this->get_party().DataF.get_one(DATA_BIT, tmp); - *this = tmp.get_reg(0); -} - } diff --git a/GC/AtlasShare.h b/GC/AtlasShare.h index bad9e10eb..4b68dd4f3 100644 --- a/GC/AtlasShare.h +++ b/GC/AtlasShare.h @@ -63,8 +63,6 @@ class AtlasShare : public ::AtlasShare>, public ShareSecret { typename T::part_type::LivePrep part_prep; SubProcessor* part_proc; - ShareThread& thread; public: - CcdPrep(DataPositions& usage, ShareThread& thread) : - BufferPrep(usage), part_prep(usage, thread), part_proc(0), - thread(thread) + static const bool use_part = true; + + CcdPrep(DataPositions& usage) : + BufferPrep(usage), part_prep(usage), part_proc(0) { } CcdPrep(SubProcessor*, DataPositions& usage) : - CcdPrep(usage, ShareThread::s()) + CcdPrep(usage) { } diff --git a/GC/CcdPrep.hpp b/GC/CcdPrep.hpp index 62f0f097e..f9535350b 100644 --- a/GC/CcdPrep.hpp +++ b/GC/CcdPrep.hpp @@ -23,6 +23,7 @@ CcdPrep::~CcdPrep() template void CcdPrep::set_protocol(typename T::Protocol& protocol) { + auto& thread = ShareThread::s(); assert(thread.MC); part_proc = new SubProcessor( thread.MC->get_part_MC(), part_prep, protocol.get_part().P); diff --git a/GC/CcdShare.h b/GC/CcdShare.h index cd6c3618f..aececad0a 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -74,13 +74,6 @@ class CcdShare : public ShamirShare, public ShareSecret> *this = input; } - void random() - { - CcdSecret tmp; - ShareThread>::s().DataF.get_one(DATA_BIT, tmp); - *this = tmp.get_reg(0); - } - This& operator^=(const This& other) { *this += other; diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index e3820e06d..73013efa5 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -136,6 +136,8 @@ class FakeSecret : public ShareInterface, public BitVec void andrs(int n, const FakeSecret& x, const FakeSecret& y) { *this = BitVec(x.a * (y.a & 1)).mask(n); } + void xor_bit(int i, FakeSecret bit) { *this ^= bit << i; } + void invert(int n, const FakeSecret& x) { *this = BitVec(~x.a).mask(n); } void random_bit() { a = random() % 2; } diff --git a/GC/MaliciousCcdShare.h b/GC/MaliciousCcdShare.h index 6f9410961..9dc63fc63 100644 --- a/GC/MaliciousCcdShare.h +++ b/GC/MaliciousCcdShare.h @@ -79,13 +79,6 @@ class MaliciousCcdShare: public MaliciousShamirShare, public ShareSecret< *this = input; } - void random() - { - MaliciousCcdSecret tmp; - ShareThread>::s().DataF.get_one(DATA_BIT, tmp); - *this = tmp.get_reg(0); - } - This& operator^=(const This& other) { *this += other; diff --git a/GC/NoShare.h b/GC/NoShare.h index 9bffccd5e..f60eccd75 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -7,12 +7,12 @@ #define GC_NOSHARE_H_ #include "Processor/DummyProtocol.h" -#include "BMR/Register.h" -#include "Tools/SwitchableOutput.h" #include "Protocols/ShareInterface.h" class InputArgs; class ArithmeticProcessor; +class BlackHole; +class SwitchableOutput; namespace GC { @@ -110,7 +110,7 @@ class NoShare : public ShareInterface typedef NoShare small_type; - typedef BlackHole out_type; + typedef SwitchableOutput out_type; static const bool is_real = false; @@ -124,6 +124,11 @@ class NoShare : public ShareInterface return "no"; } + static void specification(octetStream&) + { + fail(); + } + static int size() { return 0; @@ -172,6 +177,8 @@ class NoShare : public ShareInterface NoShare get_bit(int) const { fail(); return {}; } + void xor_bit(int, NoShare) const { fail(); } + void invert(int, NoShare) { fail(); } NoShare mask(int) const { fail(); return {}; } diff --git a/GC/Processor.h b/GC/Processor.h index e759cf05f..2dddf8df2 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -27,7 +27,7 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching static int check_args(const vector& args, int n); template - static void check_input(const U& in, int n_bits); + static void check_input(const U& in, const int* params); Machine* machine; Memories& memories; diff --git a/GC/Processor.hpp b/GC/Processor.hpp index b1539b04f..016352d3e 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -89,15 +89,15 @@ U GC::Processor::get_long_input(const int* params, else res = input_proc.get_input>(interactive, ¶ms[1]).items[0]; - int n_bits = *params; - check_input(res, n_bits); + check_input(res, params); return res; } template template -void GC::Processor::check_input(const U& in, int n_bits) +void GC::Processor::check_input(const U& in, const int* params) { + int n_bits = *params; auto test = in >> (n_bits - 1); if (n_bits == 1) { @@ -106,9 +106,17 @@ void GC::Processor::check_input(const U& in, int n_bits) } else if (not (test == 0 or test == -1)) { - throw runtime_error( - "input too large for a " + std::to_string(n_bits) - + "-bit signed integer: " + to_string(in)); + if (params[1] == 0) + throw runtime_error( + "input out of range for a " + std::to_string(n_bits) + + "-bit signed integer: " + to_string(in)); + else + throw runtime_error( + "input out of range for a " + to_string(n_bits) + + "-bit fixed-point number with " + + to_string(params[1]) + "-bit precision: " + + to_string( + mpf_class(bigint(in)) * exp2(-params[1]))); } } diff --git a/GC/RepPrep.h b/GC/RepPrep.h index 34419de79..8806e7306 100644 --- a/GC/RepPrep.h +++ b/GC/RepPrep.h @@ -22,7 +22,6 @@ class RepPrep : public PersonalPrep, ShiftableTripleBuffer ReplicatedBase* protocol; public: - RepPrep(DataPositions& usage, ShareThread& thread); RepPrep(DataPositions& usage, int input_player = PersonalPrep::SECURE); ~RepPrep(); diff --git a/GC/RepPrep.hpp b/GC/RepPrep.hpp index 5b2facacc..1c91fd395 100644 --- a/GC/RepPrep.hpp +++ b/GC/RepPrep.hpp @@ -16,13 +16,6 @@ namespace GC { -template -RepPrep::RepPrep(DataPositions& usage, ShareThread& thread) : - RepPrep(usage) -{ - (void) thread; -} - template RepPrep::RepPrep(DataPositions& usage, int input_player) : PersonalPrep(usage, input_player), protocol(0) diff --git a/GC/RuntimeBranching.h b/GC/RuntimeBranching.h index 20e062e8f..6ba0faf06 100644 --- a/GC/RuntimeBranching.h +++ b/GC/RuntimeBranching.h @@ -31,6 +31,11 @@ class RuntimeBranching { tainted = true; } + + bool is_tainted() + { + return tainted; + } }; } /* namespace GC */ diff --git a/GC/Secret.h b/GC/Secret.h index f97ad7b3f..6b37aa21c 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -84,7 +84,6 @@ class Secret template static void store_clear_in_dynamic(U& mem, const vector& accesses) { T::store_clear_in_dynamic(mem, accesses); } - static void output(T& reg); template static void load(vector< ReadAccess >& accesses, const U& mem); @@ -113,7 +112,7 @@ class Secret { T::inputbvec(processor, input_proc, args); } template static void reveal_inst(Processor& processor, const vector& args) - { processor.reveal(args); } + { T::reveal_inst(processor, args); } template static void trans(Processor& processor, int n_inputs, const vector& args); @@ -148,7 +147,6 @@ class Secret } void invert(int n, const Secret& x); void and_(int n, const Secret& x, const Secret& y, bool repeat); - void andrs(int n, const Secret& x, const Secret& y) { and_(n, x, y, true); } template void reveal(size_t n_bits, U& x); diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 562cc32ca..01c70247c 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -119,12 +119,6 @@ void Secret::store(U& mem, T::store(mem, accesses); } -template -void Secret::output(T& reg) -{ - reg.output(); -} - template Secret::Secret() { diff --git a/GC/SemiHonestRepPrep.h b/GC/SemiHonestRepPrep.h index 1d1be8307..1a1bd90d1 100644 --- a/GC/SemiHonestRepPrep.h +++ b/GC/SemiHonestRepPrep.h @@ -15,10 +15,6 @@ namespace GC class SemiHonestRepPrep : public RepPrep { public: - SemiHonestRepPrep(DataPositions& usage, ShareThread&) : - RepPrep(usage) - { - } SemiHonestRepPrep(DataPositions& usage, bool = false) : RepPrep(usage) { diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 32e756096..9fc3f4918 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -16,11 +16,6 @@ namespace GC { -SemiPrep::SemiPrep(DataPositions& usage, ShareThread&) : - SemiPrep(usage) -{ -} - SemiPrep::SemiPrep(DataPositions& usage, bool) : BufferPrep(usage), triple_generator(0) { diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 60751c382..97214c28d 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -26,7 +26,6 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer& thread); SemiPrep(DataPositions& usage, bool = true); ~SemiPrep(); diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index f8252f53e..ae10b5222 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -73,6 +73,9 @@ class SemiSecret : public SemiShare, public ShareSecret void xor_(int n, const SemiSecret& x, const SemiSecret& y) { *this = BitVec(x ^ y).mask(n); } + void xor_bit(int i, const SemiSecret& bit) + { *this ^= bit << i; } + void reveal(size_t n_bits, Clear& x); SemiSecret lsb() diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 357f91f08..9ea0d2f68 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -161,6 +161,9 @@ class RepSecretBase : public FixedVec, public ShareSecret This get_bit(int i) { return (*this >> i) & 1; } + + void xor_bit(int i, const This& bit) + { *this ^= bit << i; } }; template diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 86b1d95cb..5f995e808 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -32,9 +32,9 @@ class ShareThread Preprocessing& DataF; - ShareThread(const Names& N, OnlineOptions& opts, DataPositions& usage); - ShareThread(const Names& N, OnlineOptions& opts, Player& P, - typename T::mac_key_type mac_key, DataPositions& usage); + ShareThread(Preprocessing& prep); + ShareThread(Preprocessing& prep, Player& P, + typename T::mac_key_type mac_key); virtual ~ShareThread(); virtual typename T::MC* new_mc(typename T::mac_key_type mac_key) @@ -54,6 +54,7 @@ class StandaloneShareThread : public ShareThread, public Thread DataPositions usage; StandaloneShareThread(int i, ThreadMaster& master); + ~StandaloneShareThread(); void pre_run(); void post_run() { ShareThread::post_run(); } diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index c484e9e28..14d496115 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -18,26 +18,28 @@ namespace GC template StandaloneShareThread::StandaloneShareThread(int i, ThreadMaster& master) : - ShareThread(master.N, master.opts, usage), Thread(i, master) + ShareThread(*Preprocessing::get_new(master.opts.live_prep, + master.N, usage)), + Thread(i, master) { } template -ShareThread::ShareThread(const Names& N, OnlineOptions& opts, DataPositions& usage) : - P(0), MC(0), protocol(0), DataF( - opts.live_prep ? - *static_cast*>(new typename T::LivePrep( - usage, *this)) : - *static_cast*>(new BitPrepFiles(N, - get_prep_sub_dir(PREP_DIR, N.num_players()), - usage, BaseMachine::thread_num))) +StandaloneShareThread::~StandaloneShareThread() { + delete &this->DataF; } template -ShareThread::ShareThread(const Names& N, OnlineOptions& opts, Player& P, - typename T::mac_key_type mac_key, DataPositions& usage) : - ShareThread(N, opts, usage) +ShareThread::ShareThread(Preprocessing& prep) : + P(0), MC(0), protocol(0), DataF(prep) +{ +} + +template +ShareThread::ShareThread(Preprocessing& prep, Player& P, + typename T::mac_key_type mac_key) : + ShareThread(prep) { pre_run(P, mac_key); } @@ -45,7 +47,6 @@ ShareThread::ShareThread(const Names& N, OnlineOptions& opts, Player& P, template ShareThread::~ShareThread() { - delete &DataF; if (MC) delete MC; if (protocol) @@ -76,12 +77,6 @@ void ShareThread::post_run() { protocol->check(); MC->Check(*this->P); -#ifndef INSECURE -#ifdef VERBOSE - cerr << "Removing used pre-processed data" << endl; -#endif - DataF.prune(); -#endif } template diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 38f3d4326..5487c41b2 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -58,10 +58,8 @@ void Thread::run() P = new PlainPlayer(N, id); processor.open_input_file(N.my_num(), thread_num, master.opts.cmd_private_input_file); - processor.out.activate(N.my_num() == 0 or master.opts.interactive); - processor.setup_redirection(P->my_num(), thread_num, master.opts); - if (processor.stdout_redirect_file.is_open()) - processor.out.redirect_to_file(processor.stdout_redirect_file); + processor.setup_redirection(P->my_num(), thread_num, master.opts, + processor.out); done.push(0); pre_run(); diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 5f16229c0..060e9f118 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -58,6 +58,16 @@ Thread* ThreadMaster::new_thread(int i) template void ThreadMaster::run() { +#ifndef INSECURE + if (not opts.live_prep) + { + cerr + << "Preprocessing from file not supported by binary virtual machines" + << endl; + exit(1); + } +#endif + P = new PlainPlayer(N, "main"); machine.load_schedule(progname); diff --git a/GC/TinierShare.h b/GC/TinierShare.h index 4a57bc46c..f65835d88 100644 --- a/GC/TinierShare.h +++ b/GC/TinierShare.h @@ -106,11 +106,6 @@ class TinierShare: public Share_, SemiShare>, party.MC->get_alphai()); } - void random() - { - *this = get_party().DataF.get_part().get_bit(); - } - This lsb() const { return *this; diff --git a/GC/TinierSharePrep.h b/GC/TinierSharePrep.h index b2fbf9aac..34beaf6fb 100644 --- a/GC/TinierSharePrep.h +++ b/GC/TinierSharePrep.h @@ -25,7 +25,6 @@ class TinierSharePrep : public PersonalPrep MascotParams params; typedef typename T::whole_type secret_type; - ShareThread& thread; void buffer_triples(); void buffer_squares() { throw not_implemented(); } @@ -39,8 +38,6 @@ class TinierSharePrep : public PersonalPrep void init_real(Player& P); public: - TinierSharePrep(DataPositions& usage, ShareThread& thread, - int input_player = PersonalPrep::SECURE); TinierSharePrep(DataPositions& usage, int input_player = PersonalPrep::SECURE); TinierSharePrep(SubProcessor*, DataPositions& usage); diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index 10711b70b..57e759b9c 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -15,16 +15,8 @@ namespace GC template TinierSharePrep::TinierSharePrep(DataPositions& usage, int input_player) : - TinierSharePrep(usage, ShareThread::s(), input_player) -{ -} - -template -TinierSharePrep::TinierSharePrep(DataPositions& usage, - ShareThread& thread, int input_player) : PersonalPrep(usage, input_player), triple_generator(0), - real_triple_generator(0), - thread(thread) + real_triple_generator(0) { } @@ -87,6 +79,7 @@ void TinierSharePrep::buffer_inputs(int player) template void GC::TinierSharePrep::buffer_bits() { + auto& thread = ShareThread::s(); this->bits.push_back( BufferPrep::get_random_from_inputs(thread.P->num_players())); } diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index eae76ab59..2b8a11b79 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -14,6 +14,7 @@ template void TinierSharePrep::init_real(Player& P) { assert(real_triple_generator == 0); + auto& thread = ShareThread::s(); real_triple_generator = new typename T::whole_type::TripleGenerator( BaseMachine::s().fresh_ot_setup(), P.N, -1, OnlineOptions::singleton.batch_size, 1, params, @@ -24,6 +25,7 @@ void TinierSharePrep::init_real(Player& P) template void TinierSharePrep::buffer_secret_triples() { + auto& thread = ShareThread::s(); auto& triple_generator = real_triple_generator; assert(triple_generator != 0); params.generateBits = false; diff --git a/GC/TinySecret.h b/GC/TinySecret.h index cbd15ee43..9b6c84782 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -58,6 +58,11 @@ class VectorSecret : public Secret return part_type::size() * default_length; } + static void specification(octetStream& os) + { + T::specification(os); + } + static void read_or_generate_mac_key(string directory, const Player& P, mac_key_type& key) { @@ -150,6 +155,17 @@ class VectorSecret : public Secret return this->get_reg(i); } + void xor_bit(size_t i, const T& bit) + { + if (i < this->get_regs().size()) + XOR(this->get_reg(i), this->get_reg(i), bit); + else + { + this->resize_regs(i + 1); + this->get_reg(i) = bit; + } + } + void output(ostream& s, bool human = true) const { assert(this->get_regs().size() == default_length); @@ -179,6 +195,12 @@ class VectorSecret : public Secret { inputter.finalize(from, n_bits).mask(*this, n_bits); } + + void random_bit() + { + auto& thread = GC::ShareThread::s(); + *this = thread.DataF.get_part().get_bit(); + } }; template diff --git a/GC/TinyShare.h b/GC/TinyShare.h index 0562e7f1c..5712de618 100644 --- a/GC/TinyShare.h +++ b/GC/TinyShare.h @@ -74,13 +74,6 @@ class TinyShare : public Spdz2kShare<1, S>, public ShareSecret> *this = super::constant(input, party.P->my_num(), party.MC->get_alphai()); } - - void random() - { - TinySecret tmp; - this->get_party().DataF.get_one(DATA_BIT, tmp); - *this = tmp.get_reg(0); - } }; } /* namespace GC */ diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index b3044d5a4..82cde5e8d 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -158,8 +158,8 @@ GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i, { if (output and i == 0) { - T::clear::template generate_setup(PREP_DIR, nplayers, 128); prep_data_dir = get_prep_sub_dir(PREP_DIR, nplayers); + T::clear::write_setup(prep_data_dir); write_mac_key(prep_data_dir, my_num, nplayers, mac_key); } diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index d70e0fe2b..8525b0671 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -59,6 +59,9 @@ int main(int argc, const char** argv) online_opts.live_prep, online_opts).run(); \ break; X(64) X(128) X(256) X(192) X(384) X(512) +#ifdef RING_SIZE + X(RING_SIZE) +#endif #undef X default: cerr << "Not compiled for " << R << "-bit rings" << endl; diff --git a/Machines/real-bmr-party.cpp b/Machines/real-bmr-party.cpp index dd13d1731..42000ddf8 100644 --- a/Machines/real-bmr-party.cpp +++ b/Machines/real-bmr-party.cpp @@ -3,6 +3,8 @@ * */ +#define NO_MIXED_CIRCUITS + #include "BMR/RealProgramParty.hpp" #include "Machines/SPDZ.hpp" diff --git a/Makefile b/Makefile index fcfbee413..9d634c0e9 100644 --- a/Makefile +++ b/Makefile @@ -97,16 +97,15 @@ replicated: rep-field rep-ring rep-bin spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Offline.x mascot: mascot-party.x spdz2k mama-party.x -tldr: - -echo ARCH = -march=native >> CONFIG.mine - $(MAKE) mascot-party.x - ifeq ($(OS), Darwin) tldr: mac-setup else -tldr: mpir +tldr: mpir linux-machine-setup endif +tldr: + $(MAKE) mascot-party.x + ifeq ($(MACHINE), aarch64) tldr: simde/simde endif @@ -144,8 +143,6 @@ static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Mac Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o $(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS) -Check-Offline.x: $(PROCESSOR) - ot.x: $(OT) $(COMMON) Machines/OText_main.o Machines/OTMachine.o $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) @@ -285,11 +282,21 @@ mpir: mpir-setup -echo MY_CFLAGS += -I./local/include >> CONFIG.mine -echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine -mac-setup: +mac-setup: mac-machine-setup brew install openssl boost libsodium mpir yasm ntl - -echo MY_CFLAGS += -I/usr/local/opt/openssl/include >> CONFIG.mine - -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib >> CONFIG.mine - -echo USE_NTL = 1 >> CONFIG.mine + -echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I/opt/homebrew/opt/openssl/include -I/opt/homebrew/include >> CONFIG.mine + -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L/opt/homebrew/lib -L/opt/homebrew/opt/openssl/lib >> CONFIG.mine +# -echo USE_NTL = 1 >> CONFIG.mine + +ifeq ($(MACHINE), aarch64) +mac-machine-setup: + -echo ARCH = >> CONFIG.mine +linux-machine-setup: + -echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine +else +mac-machine-setup: +linux-machine-setup: +endif simde/simde: git submodule update --init simde diff --git a/Math/FixedVec.h b/Math/FixedVec.h index de9360664..55983e0b0 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -49,6 +49,11 @@ class FixedVec return string(1, T::type_char()); } + static void specification(octetStream& os) + { + T::specification(os); + } + template static FixedVec Mul(const FixedVec& a, const V& b) { diff --git a/Math/Integer.h b/Math/Integer.h index 5308e6537..8104724c0 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -35,6 +35,8 @@ class IntBase : public ValueInterface static int length() { return N_BITS; } static string type_string() { return "integer"; } + static void specification(octetStream& os); + static void init_default(int lgp) { (void)lgp; } static bool allows(Dtype type) { return type <= DATA_BIT; } @@ -126,6 +128,8 @@ class Integer : public IntBase Integer(const bigint& x) { *this = (x > 0) ? x.get_ui() : -x.get_ui(); } template Integer(const Z2& x) : Integer(x.get_limb(0)) {} + template + Integer(const SignedZ2& x); template Integer(const gfp_& x); Integer(int128 x) : Integer(x.get_lower()) {} diff --git a/Math/Integer.hpp b/Math/Integer.hpp index e48ad93a2..12d42ae43 100644 --- a/Math/Integer.hpp +++ b/Math/Integer.hpp @@ -8,6 +8,12 @@ template const int IntBase::N_BITS; +template +inline void IntBase::specification(octetStream& os) +{ + os.store(sizeof(T)); +} + template void IntBase::output(ostream& s,bool human) const { @@ -42,6 +48,15 @@ void Integer::reqbl(int n) } } +template +Integer::Integer(const SignedZ2& x) +{ + if (K < N_BITS and x.negative()) + a = -(-x).get_limb(0); + else + a = x.get_limb(0); +} + inline Integer::Integer(const Integer& x, int n_bits) { diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 797d60db0..1a6fe41d1 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -139,6 +139,14 @@ void gf2n_::init_multiplication() } +template +void gf2n_::specification(octetStream& os) +{ + os.store(sizeof(U)); + os.store(degree()); +} + + /* Takes 8bit x and y and returns the 16 bit product in c1 and c0 ans = (c1<<8)^c0 where c1 and c0 are 8 bit diff --git a/Math/gf2n.h b/Math/gf2n.h index 13c3e3e73..add8627cf 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -76,6 +76,8 @@ class gf2n_ : public ValueInterface static string type_short() { return "2"; } static string type_string() { return "gf2n_"; } + static void specification(octetStream& os); + static int size() { return sizeof(a); } static int size_in_bits() { return sizeof(a) * 8; } diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 9848a862b..85a668a74 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -171,6 +171,29 @@ class gf2n_long : public gf2n_ } }; +#if defined(__aarch64__) && defined(__clang__) +inline __m128i my_slli(int128 x, int i) +{ + if (i < 64) + return int128(x.get_upper() << i, x.get_lower() << i).a; + else + return int128().a; +} + +inline __m128i my_srli(int128 x, int i) +{ + if (i < 64) + return int128(x.get_upper() >> i, x.get_lower() >> i).a; + else + return int128().a; +} + +#undef _mm_slli_epi64 +#undef _mm_srli_epi64 +#define _mm_slli_epi64 my_slli +#define _mm_srli_epi64 my_srli +#endif + inline int128 int128::operator<<(const int& other) const { int128 res(_mm_slli_epi64(a, other)); diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 45db4e5d9..368bca4b4 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -15,7 +15,7 @@ Zp_Data gfpvar_::ZpD; template string gfpvar_::type_string() { - return "gfpvar"; + return "gfp"; } template @@ -30,6 +30,12 @@ char gfpvar_::type_char() return 'p'; } +template +void gfpvar_::specification(octetStream& os) +{ + os.store(pr()); +} + template int gfpvar_::length() { diff --git a/Math/gfpvar.h b/Math/gfpvar.h index 41d1ce2ba..438a935e5 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -44,6 +44,8 @@ class gfpvar_ static string type_short(); static char type_char(); + static void specification(octetStream& os); + static int length(); static int size(); static int size_in_bits(); diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 3794b69e9..d0b289b36 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -14,21 +14,43 @@ void check_ssl_file(string filename) "You can use `Scripts/setup-ssl.sh `."); } -void ssl_error(string side, string pronoun, string other, string server) +void ssl_error(string side, string other, string me) { cerr << side << "-side handshake with " << other - << " failed. Make sure " << pronoun - << " have the necessary certificate (" << PREP_DIR << server - << ".pem in the default configuration)," + << " failed. Make sure both sides " + << " have the necessary certificate (" << PREP_DIR << me + << ".pem in the default configuration on their side and " + << PREP_DIR << other << ".pem on ours)," << " and run `c_rehash ` on its location." << endl << "The certificates should be the same on every host. " << "Also make sure that it's still valid. Certificates generated " << "with `Scripts/setup-ssl.sh` expire after a month." << endl; + cerr << "See also " + "https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html" + "#handshake-failures" << endl; + + string ids[2]; + ids[side == "Client"] = other; + ids[side != "Client"] = me; + cerr << "Signature (should match the other side): "; + for (int i = 0; i < 2; i++) + { + auto filename = PREP_DIR + ids[i] + ".pem"; + ifstream cert(filename); + stringstream buffer; + buffer << cert.rdbuf(); + if (buffer.str().empty()) + cerr << "<'" << filename << "' not found>"; + else + cerr << octetStream(buffer.str()).hash(); + if (i == 0) + cerr << "/"; + } + cerr << endl; } CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : - MultiPlayer(Nms), plaintext_player(Nms, id_base), - other_player(Nms, id_base + "recv"), + MultiPlayer(Nms), ctx("P" + to_string(my_num())) { sockets.resize(num_players()); @@ -36,6 +58,16 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : senders.resize(num_players()); receivers.resize(num_players()); + vector plaintext_sockets[2]; + + for (int i = 0; i < 2; i++) + { + PlainPlayer player(Nms, id_base + (i ? "recv" : "")); + plaintext_sockets[i] = player.sockets; + close_client_socket(player.socket(my_num())); + player.sockets.clear(); + } + for (int i = 0; i < (int)sockets.size(); i++) { if (i == my_num()) @@ -47,9 +79,9 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : continue; } - sockets[i] = new ssl_socket(io_service, ctx, plaintext_player.socket(i), + sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i], "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); - other_sockets[i] = new ssl_socket(io_service, ctx, other_player.socket(i), + other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i], "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); senders[i] = new Sender(i < my_num() ? sockets[i] : other_sockets[i]); @@ -64,10 +96,6 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) : CryptoPlayer::~CryptoPlayer() { - close_client_socket(plaintext_player.socket(my_num())); - close_client_socket(other_player.socket(my_num())); - plaintext_player.sockets.clear(); - other_player.sockets.clear(); for (int i = 0; i < num_players(); i++) { delete sockets[i]; diff --git a/Networking/CryptoPlayer.h b/Networking/CryptoPlayer.h index d3bf80bfe..287f5c66f 100644 --- a/Networking/CryptoPlayer.h +++ b/Networking/CryptoPlayer.h @@ -20,7 +20,6 @@ */ class CryptoPlayer : public MultiPlayer { - PlainPlayer plaintext_player, other_player; ssl_ctx ctx; boost::asio::io_service io_service; diff --git a/Networking/Player.h b/Networking/Player.h index 668c097a7..033aa3bd1 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -373,6 +373,7 @@ class MultiPlayer : public Player T send_to_self_socket; T socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; } + T socket(int i) const { return sockets[i]; } friend class CryptoPlayer; @@ -381,8 +382,6 @@ class MultiPlayer : public Player virtual ~MultiPlayer(); - T socket(int i) const { return sockets[i]; } - // Send/Receive data to/from player i void send_long(int i, long a) const; long receive_long(int i) const; diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index a1fe34c84..fd064cd2e 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -109,7 +109,9 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) throw runtime_error( string() + "cannot connect from " + my_name + " to " + hostname + ":" + to_string(Portnum) + " after " + to_string(attempts) - + " attempts in one minute because " + strerror(connect_errno)); + + " attempts in one minute because " + strerror(connect_errno) + ". " + "https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#" + "connection-failures has more information on port requirements."); } freeaddrinfo(ai); diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index 20f5594b6..8989a0a10 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -16,7 +16,7 @@ typedef boost::asio::io_service ssl_service; void check_ssl_file(string filename); -void ssl_error(string side, string pronoun, string other, string server); +void ssl_error(string side, string other, string server); class ssl_ctx : public boost::asio::ssl::context { @@ -55,7 +55,7 @@ class ssl_socket : public boost::asio::ssl::stream handshake(ssl_socket::client); } catch (...) { - ssl_error("Client", "we", other, other); + ssl_error("Client", other, me); throw; } else @@ -65,7 +65,7 @@ class ssl_socket : public boost::asio::ssl::stream handshake(ssl_socket::server); } catch (...) { - ssl_error("Server", "they", other, me); + ssl_error("Server", other, me); throw; } diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index bc799e4a0..732850bc4 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -187,7 +187,13 @@ void NPartyTripleGenerator::generate() if (thread_num != 0) ss << "-" << thread_num; if (machine.output) + { outputFile.open(ss.str().c_str()); + if (machine.generateMACs or not T::clear::invertible) + file_signature().output(outputFile); + else + file_signature().output(outputFile); + } if (machine.generateBits) generateBits(); diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index 9eaa81e56..c32eb019f 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -89,6 +89,13 @@ DataPositions DataPositions::operator-(const DataPositions& other) const return res; } +DataPositions DataPositions::operator+(const DataPositions& other) const +{ + DataPositions res = *this; + res.increase(other); + return res; +} + void DataPositions::print_cost() const { ifstream file("cost"); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 069a36223..8f44ed253 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -19,6 +19,11 @@ using namespace std; template class dabit; +namespace GC +{ +template class ShareThread; +} + class DataTag { int t[4]; @@ -74,6 +79,7 @@ class DataPositions void increase(const DataPositions& delta); DataPositions& operator-=(const DataPositions& delta); DataPositions operator-(const DataPositions& delta) const; + DataPositions operator+(const DataPositions& delta) const; void print_cost() const; bool empty() const; bool any_more(const DataPositions& other) const; @@ -84,10 +90,15 @@ template class Data_Files; template class Machine; template class SubProcessor; +/** + * Abstract base class for preprocessing + */ template class Preprocessing : public PrepBase { protected: + static const bool use_part = false; + DataPositions& usage; map, vector>> edabits; @@ -114,6 +125,8 @@ class Preprocessing : public PrepBase template static Preprocessing* get_new(Machine& machine, DataPositions& usage, SubProcessor* proc); + static Preprocessing* get_new(bool live_prep, const Names& N, + DataPositions& usage); static Preprocessing* get_live_prep(SubProcessor* proc, DataPositions& usage); @@ -144,11 +157,15 @@ class Preprocessing : public PrepBase void get_input(T& a, typename T::open_type& x, int i); void get(vector& S, DataTag tag, const vector& regs, int vector_size); + /// Get fresh random multiplication triple virtual array get_triple(int n_bits); virtual array get_triple_no_count(int n_bits); + /// Get fresh random bit virtual T get_bit(); + /// Get fresh random value in domain virtual T get_random(); - virtual void get_dabit(T&, typename T::bit_type&); + /// Store fresh daBit in ``a`` (arithmetic part) and ``b`` (binary part) + virtual void get_dabit(T& a, typename T::bit_type& b); virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); } virtual void get_edabits(bool strict, size_t size, T* a, vector& Sb, const vector& regs) @@ -156,6 +173,7 @@ class Preprocessing : public PrepBase template void get_edabit_no_count(bool, int n_bits, edabit& eb); template + /// Get fresh edaBit chunk edabitvec get_edabitvec(bool strict, int n_bits); virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); } @@ -270,13 +288,14 @@ class Data_Files Preprocessing& DataFp; Preprocessing& DataF2; + Preprocessing& DataFb; Data_Files(Machine& machine, SubProcessor* procp = 0, SubProcessor* proc2 = 0); Data_Files(const Names& N); ~Data_Files(); - DataPositions tellg(); + DataPositions tellg() { return usage; } void seekg(DataPositions& pos); void skip(const DataPositions& pos); void prune(); @@ -289,7 +308,7 @@ class Data_Files void reset_usage() { usage.reset(); skipped.reset(); } - NamedCommStats comm_stats() { return DataFp.comm_stats() + DataF2.comm_stats(); } + NamedCommStats comm_stats(); }; template inline @@ -407,6 +426,13 @@ inline void Data_Files::purge() { DataFp.purge(); DataF2.purge(); + DataFb.purge(); +} + +template +NamedCommStats Data_Files::comm_stats() +{ + return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats(); } #endif diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 702937951..6951ed2cc 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -5,6 +5,7 @@ #include "Processor/Processor.h" #include "Protocols/dabit.h" #include "Math/Setup.h" +#include "GC/BitPrepFiles.h" #include "Protocols/MascotPrep.hpp" @@ -28,6 +29,19 @@ Preprocessing* Preprocessing::get_new( machine.template prep_dir_prefix(), usage); } +template +Preprocessing* Preprocessing::get_new( + bool live_prep, const Names& N, + DataPositions& usage) +{ + if (live_prep) + return new typename T::LivePrep(usage); + else + return new GC::BitPrepFiles(N, + get_prep_sub_dir(PREP_DIR, N.num_players()), usage, + BaseMachine::thread_num); +} + template Sub_Data_Files::Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num) : @@ -96,7 +110,7 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, dabit_buffer.setup( PrepBase::get_filename(prep_data_dir, DATA_DABIT, - type_short, my_num, thread_num), 1, type_string, + type_short, my_num, thread_num), dabit::size(), type_string, DataPositions::dtype_names[DATA_DABIT]); input_buffers.resize(num_players); @@ -106,7 +120,7 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, type_short, i, my_num, thread_num); if (i == my_num) my_input_buffers.setup(filename, - T::size() * 3 / 2, type_string); + T::size() + T::clear::size(), type_string); else input_buffers[i].setup(filename, T::size(), type_string); @@ -122,7 +136,10 @@ Data_Files::Data_Files(Machine& machine, SubProcessor< SubProcessor* proc2) : usage(machine.get_N().num_players()), DataFp(*Preprocessing::get_new(machine, usage, procp)), - DataF2(*Preprocessing::get_new(machine, usage, proc2)) + DataF2(*Preprocessing::get_new(machine, usage, proc2)), + DataFb( + *Preprocessing::get_new(machine.live_prep, + machine.get_N(), usage)) { } @@ -130,7 +147,8 @@ template Data_Files::Data_Files(const Names& N) : usage(N.num_players()), DataFp(*new Sub_Data_Files(N, usage)), - DataF2(*new Sub_Data_Files(N, usage)) + DataF2(*new Sub_Data_Files(N, usage)), + DataFb(*new Sub_Data_Files(N, usage)) { } @@ -150,6 +168,7 @@ Data_Files::~Data_Files() DataF2.data_sent() * 1e-6 << " MB" << endl; #endif delete &DataF2; + delete &DataFb; } template @@ -166,6 +185,12 @@ Sub_Data_Files::~Sub_Data_Files() template void Sub_Data_Files::seekg(DataPositions& pos) { + if (T::LivePrep::use_part) + { + get_part().seekg(pos); + return; + } + DataFieldType field_type = T::clear::field_type(); for (int dtype = 0; dtype < N_DTYPE; dtype++) if (T::clear::allows(Dtype(dtype))) @@ -181,6 +206,7 @@ void Sub_Data_Files::seekg(DataPositions& pos) setup_extended(it->first); extended[it->first].seekg(it->second); } + dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]); } template @@ -188,6 +214,7 @@ void Data_Files::seekg(DataPositions& pos) { DataFp.seekg(pos); DataF2.seekg(pos); + DataFb.seekg(pos); usage = pos; } @@ -210,6 +237,9 @@ void Sub_Data_Files::prune() input_buffers[j].prune(); for (auto it : extended) it.second.prune(); + dabit_buffer.prune(); + if (part != 0) + part->prune(); } template @@ -217,6 +247,7 @@ void Data_Files::prune() { DataFp.prune(); DataF2.prune(); + DataFb.prune(); } template @@ -229,6 +260,7 @@ void Sub_Data_Files::purge() input_buffers[j].purge(); for (auto it : extended) it.second.purge(); + dabit_buffer.purge(); } template @@ -280,11 +312,12 @@ void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, ifstream* f = new ifstream(filename); if (f->fail()) throw runtime_error("cannot open " + filename); + check_file_signature(*f, filename); edabit_buffers[n_bits] = f; } auto& buffer = *edabit_buffers[n_bits]; if (buffer.peek() == EOF) - buffer.seekg(0); + buffer.seekg(file_signature().get_length()); edabitvec eb; eb.input(n_bits, buffer); this->edabits[{strict, n_bits}].push_back(eb); diff --git a/Processor/Input.h b/Processor/Input.h index 1e5aa1035..9816c3578 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -15,6 +15,9 @@ using namespace std; class ArithmeticProcessor; +/** + * Abstract base for input protocols + */ template class InputBase { @@ -45,18 +48,28 @@ class InputBase InputBase(SubProcessor* proc); virtual ~InputBase(); + /// Initialize input round for ``player`` virtual void reset(int player) = 0; + /// Initialize input round for all players void reset_all(Player& P); + /// Schedule input from me virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; + /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; + /// Schedule input from all players void add_from_all(const clear& input); + /// Send my inputs virtual void send_mine() = 0; + /// Run input protocol for all players virtual void exchange(); + /// Get share for next input of mine virtual T finalize_mine() = 0; + /// Store share for next input from ``player`` from buffer ``o`` in ``target`` virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; + /// Get share for next input from ``player` virtual T finalize(int player, int n_bits = -1); void raw_input(SubProcessor& proc, const vector& args, int size); diff --git a/Processor/InputTuple.h b/Processor/InputTuple.h index e7a1c9cef..af38b94b0 100644 --- a/Processor/InputTuple.h +++ b/Processor/InputTuple.h @@ -19,6 +19,11 @@ struct InputTuple static string type_string() { return T::type_string(); } + static void specification(octetStream& os) + { + T::specification(os); + } + InputTuple() {} InputTuple(const T& share, const typename T::open_type& value) : share(share), value(value) {} diff --git a/Processor/Instruction.h b/Processor/Instruction.h index ad04cefc2..ca062cbcb 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -14,6 +14,7 @@ using namespace std; template class Machine; template class Processor; class ArithmeticProcessor; +class SwitchableOutput; /* * Opcode constants @@ -306,12 +307,6 @@ enum RegType { MAX_REG_TYPE, }; -enum SecrecyType { - SECRET, - CLEAR, - MAX_SECRECY_TYPE -}; - template struct TempVars { typename sgf2n::clear ans2; @@ -387,6 +382,10 @@ class Instruction : public BaseInstruction void shuffle(ArithmeticProcessor& Proc) const; void bitdecint(ArithmeticProcessor& Proc) const; + + template + void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0, + T* nan = 0) const; }; #endif diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 3d3db1fcf..72184d9e3 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -328,6 +328,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // write to external client, input is : opcode num_args, client_id, message_type, var1, var2 ... case WRITESOCKETC: + case WRITESOCKETS: case WRITESOCKETSHARE: case WRITESOCKETINT: num_var_args = get_int(s) - 3; @@ -336,8 +337,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) n = get_int(s); get_vector(num_var_args, start, s); break; - case WRITESOCKETS: - throw runtime_error("sending MACs to client not supported any more"); case READCLIENTPUBLICKEY: case INITSECURESOCKET: case RESPSECURESOCKET: @@ -1070,31 +1069,19 @@ inline void Instruction::execute(Processor& Proc) const } break; case PRINTREGPLAIN: - { - Proc.out << Proc.read_Cp(r[0]) << flush; - } - break; + print(Proc.out, &Proc.read_Cp(r[0])); + return; case CONDPRINTPLAIN: if (not Proc.read_Cp(r[0]).is_zero()) { - auto v = Proc.read_Cp(r[1]); - auto p = Proc.read_Cp(r[2]); - if (p.is_zero()) - Proc.out << v << flush; - else - Proc.out << bigint::get_float(v, p, {}, {}) << flush; + print(Proc.out, &Proc.read_Cp(r[1]), &Proc.read_Cp(r[2])); } - break; + return; case PRINTFLOATPLAIN: - { - auto nan = Proc.read_Cp(start[4]); - typename sint::clear v = Proc.read_Cp(start[0]); - typename sint::clear p = Proc.read_Cp(start[1]); - typename sint::clear z = Proc.read_Cp(start[2]); - typename sint::clear s = Proc.read_Cp(start[3]); - bigint::output_float(Proc.out, bigint::get_float(v, p, z, s), nan); - } - break; + print(Proc.out, &Proc.read_Cp(start[0]), &Proc.read_Cp(start[1]), + &Proc.read_Cp(start[2]), &Proc.read_Cp(start[3]), + &Proc.read_Cp(start[4])); + return; case CONDPRINTSTR: if (not Proc.read_Cp(r[0]).is_zero()) { @@ -1124,9 +1111,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.stop(n); break; case RUN_TAPE: - Proc.DataF.skip( - Proc.machine.run_tapes(start, &Proc.DataF.DataFp, - &Proc.share_thread.DataF)); + Proc.machine.run_tapes(start, Proc.DataF); break; case JOIN_TAPE: Proc.machine.join_tape(r[0]); @@ -1186,15 +1171,19 @@ inline void Instruction::execute(Processor& Proc) const Proc.read_socket_private(Proc.read_Ci(r[0]), start, n, true); break; case WRITESOCKETINT: - Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start, n); + Proc.write_socket(INT, false, Proc.read_Ci(r[0]), r[1], start, n); break; case WRITESOCKETC: - Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start, n); + Proc.write_socket(CINT, false, Proc.read_Ci(r[0]), r[1], start, n); + break; + case WRITESOCKETS: + // Send shares + MACs + Proc.write_socket(SINT, true, Proc.read_Ci(r[0]), r[1], start, n); break; case WRITESOCKETSHARE: // Send only shares, no MACs // N.B. doesn't make sense to have a corresponding read instruction for this - Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start, n); + Proc.write_socket(SINT, false, Proc.read_Ci(r[0]), r[1], start, n); break; case WRITEFILESHARE: // Write shares to file system @@ -1323,4 +1312,29 @@ void Program::execute(Processor& Proc) const } } +template +void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) const +{ + if (size > 1) + out << "["; + for (int i = 0; i < size; i++) + { + if (p == 0) + out << v[i]; + else if (s == 0) + out << bigint::get_float(v[i], p[i], {}, {}); + else + { + assert(z != 0); + assert(nan != 0); + bigint::output_float(out, bigint::get_float(v[i], p[i], s[i], z[i]), + nan[i]); + } + if (i < size - 1) + out << ", "; + } + if (size > 1) + out << "]"; +} + #endif diff --git a/Processor/Machine.h b/Processor/Machine.h index d6ab84848..3f23dc9f9 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -42,9 +42,6 @@ class Machine : public BaseMachine typename sgf2n::mac_key_type alpha2i; typename sint::bit_type::mac_key_type alphabi; - // Keep record of used offline data - DataPositions pos; - Player* P; void load_program(const string& threadname, const string& filename); @@ -83,8 +80,8 @@ class Machine : public BaseMachine const Names& get_N() { return N; } - DataPositions run_tapes(const vector &args, Preprocessing *prep, - Preprocessing *bit_prep); + DataPositions run_tapes(const vector &args, + Data_Files& DataF); void fill_buffers(int thread_number, int tape_number, Preprocessing *prep, Preprocessing *bit_prep); @@ -93,7 +90,8 @@ class Machine : public BaseMachine Preprocessing *prep, true_type); template void fill_matmul(int, int, Preprocessing*, false_type) {} - DataPositions run_tape(int thread_number, int tape_number, int arg); + DataPositions run_tape(int thread_number, int tape_number, int arg, + const DataPositions& pos); DataPositions join_tape(int thread_number); void run(); diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index bd08aa39e..804dc51aa 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -92,9 +92,6 @@ Machine::Machine(int my_number, Names& playerNames, exit(1); } - // Keep record of used offline data - pos.set_num_players(N.num_players()); - load_schedule(progname_str); // remove persistence if necessary @@ -161,14 +158,16 @@ void Machine::load_program(const string& threadname, template DataPositions Machine::run_tapes(const vector& args, - Preprocessing* prep, Preprocessing* bit_prep) + Data_Files& DataF) { assert(args.size() % 3 == 0); for (unsigned i = 0; i < args.size(); i += 3) - fill_buffers(args[i], args[i + 1], prep, bit_prep); + fill_buffers(args[i], args[i + 1], &DataF.DataFp, &DataF.DataFb); DataPositions res(N.num_players()); for (unsigned i = 0; i < args.size(); i += 3) - res.increase(run_tape(args[i], args[i + 1], args[i + 2])); + res.increase( + run_tape(args[i], args[i + 1], args[i + 2], DataF.tellg() + res)); + DataF.skip(res); return res; } @@ -281,7 +280,7 @@ void Machine::fill_matmul(int thread_number, int tape_number, template DataPositions Machine::run_tape(int thread_number, int tape_number, - int arg) + int arg, const DataPositions& pos) { if (size_t(thread_number) >= tinfo.size()) throw overflow("invalid thread number", thread_number, tinfo.size()); @@ -294,7 +293,7 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, if (progs[tape_number].usage_unknown()) { #ifndef INSECURE - if (not opts.live_prep) + if (not opts.live_prep and thread_number != 0) { cerr << "Internally called tape " << tape_number << " has unknown offline data usage" << endl; @@ -328,7 +327,7 @@ void Machine::run() timer[0].start(); // run main tape - pos.increase(run_tape(0, 0, 0)); + run_tape(0, 0, 0, N.num_players()); join_tape(0); print_compiler(); @@ -341,8 +340,8 @@ void Machine::run() queues[i]->schedule(-1); } - // reset to sum actual usage - pos.reset(); + // sum actual usage + DataPositions pos(N.num_players()); #ifdef DEBUG_THREADS cerr << "Waiting for all clients to finish" << endl; diff --git a/Processor/Memory.h b/Processor/Memory.h index 3c5093952..2c4a3d2e3 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -15,22 +15,29 @@ template istream& operator>>(istream& s,Memory& M); #include "Processor/Program.h" #include "Tools/CheckVector.h" +template +class MemoryPart : public CheckVector +{ +public: + void minimum_size(size_t size); +}; + template class Memory { public: - CheckVector MS; - CheckVector MC; + MemoryPart MS; + MemoryPart MC; - void resize_s(int sz) + void resize_s(size_t sz) { MS.resize(sz); } - void resize_c(int sz) + void resize_c(size_t sz) { MC.resize(sz); } - unsigned size_s() + size_t size_s() { return MS.size(); } - unsigned size_c() + size_t size_c() { return MC.size(); } template @@ -40,23 +47,23 @@ class Memory throw overflow("memory", i, M.size()); } - const typename T::clear& read_C(int i) const + const typename T::clear& read_C(size_t i) const { check_index(MC, i); return MC[i]; } - const T& read_S(int i) const + const T& read_S(size_t i) const { check_index(MS, i); return MS[i]; } - void write_C(unsigned int i,const typename T::clear& x) + void write_C(size_t i,const typename T::clear& x) { check_index(MC, i); MC[i]=x; } - void write_S(unsigned int i,const T& x) + void write_S(size_t i,const T& x) { check_index(MS, i); MS[i]=x; diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index 44e7d3432..c3c3e01bf 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -8,27 +8,23 @@ void Memory::minimum_size(RegType secret_type, RegType clear_type, const Program &program, const string& threadname) { (void) threadname; - unsigned sizes[MAX_SECRECY_TYPE]; - sizes[SECRET]= program.direct_mem(secret_type); - sizes[CLEAR] = program.direct_mem(clear_type); - if (sizes[SECRET] > size_s()) - { -#ifdef DEBUG_MEMORY - cerr << threadname << " needs more secret " << T::type_string() << " memory, resizing to " - << sizes[SECRET] << endl; -#endif - resize_s(sizes[SECRET]); - } - if (sizes[CLEAR] > size_c()) - { -#ifdef DEBUG_MEMORY - cerr << threadname << " needs more clear " << T::type_string() << " memory, resizing to " - << sizes[CLEAR] << endl; -#endif - resize_c(sizes[CLEAR]); - } + MS.minimum_size(program.direct_mem(secret_type)); + MC.minimum_size(program.direct_mem(clear_type)); } +template +void MemoryPart::minimum_size(size_t size) +{ + try + { + if (size > this->size()) + this->resize(size); + } + catch (bad_alloc&) + { + throw insufficient_memory(size, T::type_string()); + } +} template ostream& operator<<(ostream& s,const Memory& M) diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index b9901a0c6..cffaded40 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -8,6 +8,7 @@ #include "OfflineMachine.h" #include "Protocols/mac_key.hpp" +#include "Tools/Buffer.h" template template @@ -39,8 +40,8 @@ int OfflineMachine::run() T::bit_type::mac_key_type::init_field(); auto binary_mac_key = read_generate_write_mac_key< typename T::bit_type::part_type>(P); - GC::ShareThread thread(playerNames, - OnlineOptions::singleton, P, binary_mac_key, usage); + typename T::bit_type::LivePrep bit_prep(usage); + GC::ShareThread thread(bit_prep, P, binary_mac_key); generate(); generate(); @@ -74,6 +75,7 @@ void OfflineMachine::generate() if (my_usage > 0) { ofstream out(filename, iostream::out | iostream::binary); + file_signature().output(out); if (i == DATA_DABIT) { for (long long j = 0; @@ -108,6 +110,7 @@ void OfflineMachine::generate() if (n_inputs > 0) { ofstream out(filename, iostream::out | iostream::binary); + file_signature().output(out); InputTuple tuple; for (long long j = 0; j < DIV_CEIL(n_inputs, BUFFER_SIZE) * BUFFER_SIZE; j++) @@ -138,6 +141,7 @@ void OfflineMachine::generate() if (total > 0) { ofstream out(filename, ios::binary); + file_signature().output(out); for (int i = 0; i < DIV_CEIL(total, batch) * batch; i++) preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits, out); diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index b0965ae0d..577ab9f44 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -26,7 +26,7 @@ class thread_info static void* Main_Func(void *ptr); - static void purge_preprocessing(const Names& N); + static void purge_preprocessing(const Names& N, int thread_num); template static void print_usage(ostream& o, const vector& regs, diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 1ef7da055..cb25b4261 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -352,7 +352,7 @@ void* thread_info::Main_Func(void* ptr) catch (...) { thread_info* ti = (thread_info*)ptr; - ti->purge_preprocessing(ti->machine->get_N()); + ti->purge_preprocessing(ti->machine->get_N(), ti->thread_num); throw; } #endif @@ -361,13 +361,17 @@ void* thread_info::Main_Func(void* ptr) template -void thread_info::purge_preprocessing(const Names& N) +void thread_info::purge_preprocessing(const Names& N, int thread_num) { cerr << "Purging preprocessed data because something is wrong" << endl; try { Data_Files df(N); df.purge(); + DataPositions pos; + Sub_Data_Files bit_df(N, pos, thread_num); + bit_df.get_part(); + bit_df.purge(); } catch(...) { diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index f728288ff..4e944d624 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -249,7 +249,7 @@ int OnlineMachine::run() catch(...) { if (not online_opts.live_prep) - thread_info::purge_preprocessing(playerNames); + thread_info::purge_preprocessing(playerNames, 0); throw; } #endif diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index df0ff7b49..03fa23793 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -36,13 +36,9 @@ OnlineOptions::OnlineOptions() : playerno(-1) } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, - const char** argv, int default_batch_size, bool default_live_prep, - bool variable_prime_length) : + const char** argv, false_type) : OnlineOptions() { - if (default_batch_size <= 0) - default_batch_size = batch_size; - opt.syntax = std::string(argv[0]) + " [OPTIONS] [] "; opt.add( @@ -78,6 +74,58 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--output-file" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "This player's number (required if not given before program name)", // Help description. + "-p", // Flag token. + "--player" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Verbose output", // Help description. + "-v", // Flag token. + "--verbose" // Flag token. + ); + opt.add( + "4", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Batch size for sacrifice (3-5, default: 4)", // Help description. + "-B", // Flag token. + "--bucket-size" // Flag token. + ); + + opt.parse(argc, argv); + + interactive = opt.isSet("-I"); + + opt.get("-IF")->getString(cmd_private_input_file); + opt.get("-OF")->getString(cmd_private_output_file); + + opt.get("--bucket-size")->getInt(bucket_size); + +#ifndef VERBOSE + verbose = opt.isSet("--verbose"); +#endif + + opt.resetArgs(); +} + +OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, + const char** argv, int default_batch_size, bool default_live_prep, + bool variable_prime_length) : + OnlineOptions(opt, argc, argv, false_type()) +{ + if (default_batch_size <= 0) + default_batch_size = batch_size; + string default_lgp = to_string(lgp); if (variable_prime_length) { @@ -121,15 +169,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-L", // Flag token. "--live-preprocessing" // Flag token. ); - opt.add( - "", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "This player's number (required if not given before program name)", // Help description. - "-p", // Flag token. - "--player" // Flag token. - ); opt.add( to_string(default_batch_size).c_str(), // Default. @@ -170,28 +209,9 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-d", // Flag token. "--direct" // Flag token. ); - opt.add( - "4", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Batch size for sacrifice (3-5, default: 4)", // Help description. - "-B", // Flag token. - "--bucket-size" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Verbose output", // Help description. - "-v", // Flag token. - "--verbose" // Flag token. - ); opt.parse(argc, argv); - interactive = opt.isSet("-I"); if (variable_prime_length) { opt.get("--lgp")->getInt(lgp); @@ -208,17 +228,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, opt.get("--memory")->getString(memtype); bits_from_squares = opt.isSet("-Q"); - opt.get("-IF")->getString(cmd_private_input_file); - opt.get("-OF")->getString(cmd_private_output_file); - direct = opt.isSet("--direct"); - opt.get("--bucket-size")->getInt(bucket_size); - -#ifndef VERBOSE - verbose = opt.isSet("--verbose"); -#endif - opt.resetArgs(); } diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 615afc4c1..32c80fc2b 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -31,6 +31,8 @@ class OnlineOptions bool verbose; OnlineOptions(); + OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, + false_type); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, bool variable_prime_length = false); diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index b4fd58a99..b99d091e6 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -9,15 +9,8 @@ string PrepBase::get_suffix(int thread_num) { -#ifdef INSECURE (void) thread_num; return ""; -#else - if (thread_num >= 0) - return "-T" + to_string(thread_num); - else - return ""; -#endif } string PrepBase::get_filename(const string& prep_data_dir, diff --git a/Processor/Processor.h b/Processor/Processor.h index b62c64c24..d9141855c 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -31,7 +31,7 @@ class SubProcessor DataPositions bit_usage; - void resize(int size) { C.resize(size); S.resize(size); } + void resize(size_t size) { C.resize(size); S.resize(size); } template friend class Processor; template friend class SPDZ; @@ -64,10 +64,10 @@ class SubProcessor void muls(const vector& reg, int size); void mulrs(const vector& reg); void dotprods(const vector& reg, int size); - void matmuls(const vector& source, const Instruction& instruction, int a, - int b); - void matmulsm(const CheckVector& source, const Instruction& instruction, int a, - int b); + void matmuls(const vector& source, const Instruction& instruction, size_t a, + size_t b); + void matmulsm(const CheckVector& source, const Instruction& instruction, size_t a, + size_t b); void conv2ds(const Instruction& instruction); void input_personal(const vector& args); @@ -82,12 +82,12 @@ class SubProcessor return C; } - T& get_S_ref(int i) + T& get_S_ref(size_t i) { return S[i]; } - typename T::clear& get_C_ref(int i) + typename T::clear& get_C_ref(size_t i) { return C[i]; } @@ -136,11 +136,11 @@ class ArithmeticProcessor : public ProcessorBase return thread_num; } - const long& read_Ci(int i) const + const long& read_Ci(size_t i) const { return Ci[i]; } - long& get_Ci_ref(int i) + long& get_Ci_ref(size_t i) { return Ci[i]; } - void write_Ci(int i,const long& x) + void write_Ci(size_t i, const long& x) { Ci[i]=x; } CheckVector& get_Ci() { return Ci; } @@ -190,30 +190,30 @@ class Processor : public ArithmeticProcessor const Program& program); ~Processor(); - const typename sgf2n::clear& read_C2(int i) const + const typename sgf2n::clear& read_C2(size_t i) const { return Proc2.C[i]; } - const sgf2n& read_S2(int i) const + const sgf2n& read_S2(size_t i) const { return Proc2.S[i]; } - typename sgf2n::clear& get_C2_ref(int i) + typename sgf2n::clear& get_C2_ref(size_t i) { return Proc2.C[i]; } - sgf2n& get_S2_ref(int i) + sgf2n& get_S2_ref(size_t i) { return Proc2.S[i]; } - void write_C2(int i,const typename sgf2n::clear& x) + void write_C2(size_t i,const typename sgf2n::clear& x) { Proc2.C[i]=x; } - void write_S2(int i,const sgf2n& x) + void write_S2(size_t i,const sgf2n& x) { Proc2.S[i]=x; } - const typename sint::clear& read_Cp(int i) const + const typename sint::clear& read_Cp(size_t i) const { return Procp.C[i]; } - const sint & read_Sp(int i) const + const sint & read_Sp(size_t i) const { return Procp.S[i]; } - typename sint::clear& get_Cp_ref(int i) + typename sint::clear& get_Cp_ref(size_t i) { return Procp.C[i]; } - sint & get_Sp_ref(int i) + sint & get_Sp_ref(size_t i) { return Procp.S[i]; } - void write_Cp(int i,const typename sint::clear& x) + void write_Cp(size_t i,const typename sint::clear& x) { Procp.C[i]=x; } - void write_Sp(int i,const sint & x) + void write_Sp(size_t i,const sint & x) { Procp.S[i]=x; } void check(); @@ -229,8 +229,8 @@ class Processor : public ArithmeticProcessor // Access to external client sockets for reading clear/shared data void read_socket_ints(int client_id, const vector& registers, int size); - void write_socket(const RegType reg_type, int socket_id, int message_type, - const vector& registers, int size); + void write_socket(const RegType reg_type, bool send_macs, int socket_id, + int message_type, const vector& registers, int size); void read_socket_vector(int client_id, const vector& registers, int size); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index dd1e82382..ebbc1c8cc 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -70,7 +70,7 @@ Processor::Processor(int thread_num,Player& P, const Program& program) : ArithmeticProcessor(machine.opts, thread_num),DataF(machine, &Procp, &Proc2),P(P), MC2(MC2),MCp(MCp),machine(machine), - share_thread(machine.get_N(), machine.opts, P, machine.get_bit_mac_key(), DataF.usage), + share_thread(DataF.DataFb, P, machine.get_bit_mac_key()), Procb(machine.bit_memories), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), privateOutput2(Proc2),privateOutputp(Procp), @@ -94,21 +94,8 @@ Processor::Processor(int thread_num,Player& P, secure_prng.ReSeed(); shared_prng.SeedGlobally(P, false); - // only output on party 0 if not interactive - bool always_stdout = machine.opts.cmd_private_output_file == "."; - bool output = P.my_num() == 0 or machine.opts.interactive or always_stdout; - out.activate(output); - Procb.out.activate(output); - - if (not always_stdout) - setup_redirection(P.my_num(), thread_num, opts); - - if (stdout_redirect_file.is_open()) - { - out.redirect_to_file(stdout_redirect_file); - Procb.out.redirect_to_file(stdout_redirect_file); - } - + setup_redirection(P.my_num(), thread_num, opts, out); + Procb.out = out; } @@ -266,8 +253,9 @@ void Processor::split(const Instruction& instruction) // If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to // determine the data structure being sent in a message. template -void Processor::write_socket(const RegType reg_type, int socket_id, - int message_type, const vector& registers, int size) +void Processor::write_socket(const RegType reg_type, + bool send_macs, int socket_id, int message_type, + const vector& registers, int size) { int m = registers.size(); socket_stream.reset_write_head(); @@ -283,9 +271,12 @@ void Processor::write_socket(const RegType reg_type, int socket_id, { if (reg_type == SINT) { - // Send vector of secret shares - get_Sp_ref(registers[i] + j).pack(socket_stream, - sint::get_rec_factor(P.my_num(), P.num_players())); + // Send vector of secret shares and optionally macs + if (send_macs) + get_Sp_ref(registers[i] + j).pack(socket_stream); + else + get_Sp_ref(registers[i] + j).pack(socket_stream, + sint::get_rec_factor(P.my_num(), P.num_players())); } else if (reg_type == CINT) { @@ -522,7 +513,7 @@ void SubProcessor::dotprods(const vector& reg, int size) template void SubProcessor::matmuls(const vector& source, - const Instruction& instruction, int a, int b) + const Instruction& instruction, size_t a, size_t b) { auto& dim = instruction.get_start(); auto A = source.begin() + a; @@ -549,7 +540,7 @@ void SubProcessor::matmuls(const vector& source, template void SubProcessor::matmulsm(const CheckVector& source, - const Instruction& instruction, int a, int b) + const Instruction& instruction, size_t a, size_t b) { auto& dim = instruction.get_start(); auto C = S.begin() + (instruction.get_r(0)); diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.cpp index f2d34bb97..0fa1ab529 100644 --- a/Processor/ProcessorBase.cpp +++ b/Processor/ProcessorBase.cpp @@ -5,6 +5,11 @@ #include "ProcessorBase.hpp" +ProcessorBase::ProcessorBase() : + input_counter(0), arg(0) +{ +} + string ProcessorBase::get_parameterized_filename(int my_num, int thread_num, const string& prefix) { string filename = prefix + "-P" + to_string(my_num) + "-" + to_string(thread_num); @@ -22,12 +27,18 @@ void ProcessorBase::open_input_file(int my_num, int thread_num, } void ProcessorBase::setup_redirection(int my_num, int thread_num, - OnlineOptions& opts) + OnlineOptions& opts, SwitchableOutput& out) { - if (not opts.cmd_private_output_file.empty()) + // only output on party 0 if not interactive + bool always_stdout = opts.cmd_private_output_file == "."; + bool output = my_num == 0 or opts.interactive or always_stdout; + out.activate(output); + + if (not (opts.cmd_private_output_file.empty() or always_stdout)) { const string stdout_filename = get_parameterized_filename(my_num, thread_num, opts.cmd_private_output_file); stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out); + out.redirect_to_file(stdout_redirect_file); } } diff --git a/Processor/ProcessorBase.h b/Processor/ProcessorBase.h index 3c7d71166..d30de5d30 100644 --- a/Processor/ProcessorBase.h +++ b/Processor/ProcessorBase.h @@ -12,6 +12,7 @@ using namespace std; #include "Tools/ExecutionStats.h" +#include "Tools/SwitchableOutput.h" #include "OnlineOptions.h" class ProcessorBase @@ -21,6 +22,7 @@ class ProcessorBase ifstream input_file; string input_filename; + size_t input_counter; protected: // Optional argument to tape @@ -34,6 +36,8 @@ class ProcessorBase ofstream stdout_redirect_file; + ProcessorBase(); + void pushi(long x) { stacki.push(x); } void popi(long& x) { x = stacki.top(); stacki.pop(); } @@ -55,7 +59,8 @@ class ProcessorBase template T get_input(istream& is, const string& input_filename, const int* params); - void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts); + void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts, + SwitchableOutput& out); }; #endif /* PROCESSOR_PROCESSORBASE_H_ */ diff --git a/Processor/ProcessorBase.hpp b/Processor/ProcessorBase.hpp index 06b26a996..9af3e9928 100644 --- a/Processor/ProcessorBase.hpp +++ b/Processor/ProcessorBase.hpp @@ -42,8 +42,9 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co res.read(input_file, params); if (input_file.fail()) { - throw input_error(T::NAME, input_filename, input_file); + throw input_error(T::NAME, input_filename, input_file, input_counter); } + input_counter++; return res; } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index dfeb02bfd..add3f43cc 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -42,7 +42,7 @@ RingMachine::RingMachine(int argc, const char** argv, case L: \ machine.template run, V>(); \ break; - X(64) X(72) X(128) + X(64) X(72) X(128) X(192) #ifdef RING_SIZE X(RING_SIZE) #endif diff --git a/Programs/Source/mnist_49.mpc b/Programs/Source/mnist_49.mpc index 05218130a..da5b1de9b 100644 --- a/Programs/Source/mnist_49.mpc +++ b/Programs/Source/mnist_49.mpc @@ -8,6 +8,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program) +MultiArray.disable_index_checks() n_examples = 11791 n_test = 1991 diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index a1250b5df..9dc8a6851 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -10,6 +10,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program, adapt_ring=True) +MultiArray.disable_index_checks() if 'profile' in program.args: print('Compiling for profiling') diff --git a/Programs/Source/mnist_full_B.mpc b/Programs/Source/mnist_full_B.mpc index 84cfe615b..a7cdb3688 100644 --- a/Programs/Source/mnist_full_B.mpc +++ b/Programs/Source/mnist_full_B.mpc @@ -8,6 +8,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program, adapt_ring=True) +MultiArray.disable_index_checks() if 'profile' in program.args: print('Compiling for profiling') diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index cd6603ea2..6ea76b260 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -8,6 +8,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program, adapt_ring=True) +MultiArray.disable_index_checks() if 'profile' in program.args: print('Compiling for profiling') diff --git a/Programs/Source/mnist_full_D.mpc b/Programs/Source/mnist_full_D.mpc index 7ebe1904e..68d12f977 100644 --- a/Programs/Source/mnist_full_D.mpc +++ b/Programs/Source/mnist_full_D.mpc @@ -8,6 +8,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program, True) +MultiArray.disable_index_checks() if 'profile' in program.args: print('Compiling for profiling') diff --git a/Programs/Source/tf.mpc b/Programs/Source/tf.mpc index e285dd917..0c51cf581 100644 --- a/Programs/Source/tf.mpc +++ b/Programs/Source/tf.mpc @@ -30,6 +30,8 @@ layers[0].X.input_from(0) for layer in layers: layer.input_from(0, raw='raw' in program.args) +sint(0).reveal().store_in_mem(0) + start_timer(1) opt.forward(1, keep_intermediate=False) stop_timer(1) diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h index 1a6d66d4f..3dd34d173 100644 --- a/Protocols/Atlas.h +++ b/Protocols/Atlas.h @@ -8,6 +8,10 @@ #include "Replicated.h" +/** + * ATLAS protocol (simple version). + * Uses double sharings to reduce degree of Shamir secret sharing. + */ template class Atlas : public ProtocolBase { diff --git a/Protocols/AtlasPrep.h b/Protocols/AtlasPrep.h index 489f535af..666952d37 100644 --- a/Protocols/AtlasPrep.h +++ b/Protocols/AtlasPrep.h @@ -8,6 +8,9 @@ #include "ReplicatedPrep.h" +/** + * ATLAS preprocessing. + */ template class AtlasPrep : public ReplicatedPrep { @@ -21,6 +24,7 @@ class AtlasPrep : public ReplicatedPrep { } + /// Input tuples from random sharings void buffer_inputs(int player) { assert(this->protocol and this->proc); diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 71bc96fce..e0c24e49e 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -17,6 +17,9 @@ template class SubProcessor; template class MAC_Check_Base; class Player; +/** + * Beaver multiplication + */ template class Beaver : public ProtocolBase { diff --git a/Protocols/ChaiGearPrep.h b/Protocols/ChaiGearPrep.h index 21e18eeb6..fab6e21fe 100644 --- a/Protocols/ChaiGearPrep.h +++ b/Protocols/ChaiGearPrep.h @@ -8,6 +8,9 @@ #include "FHEOffline/SimpleGenerator.h" +/** + * HighGear/ChaiGear preprocessing + */ template class ChaiGearPrep : public MaliciousRingPrep { diff --git a/Protocols/CowGearPrep.h b/Protocols/CowGearPrep.h index 93c973489..e15d3feab 100644 --- a/Protocols/CowGearPrep.h +++ b/Protocols/CowGearPrep.h @@ -11,6 +11,9 @@ class PairwiseMachine; template class PairwiseGenerator; +/** + * LowGear/CowGear preprocessing + */ template class CowGearPrep : public MaliciousRingPrep { diff --git a/Protocols/FakePrep.h b/Protocols/FakePrep.h index 8597ffc4b..4fd859e5b 100644 --- a/Protocols/FakePrep.h +++ b/Protocols/FakePrep.h @@ -69,6 +69,16 @@ class FakePrep : public BufferPrep } } + void buffer_inputs(int) + { + this->inputs.resize(1); + for (int i = 0; i < 1000; i++) + { + auto r = G.get(); + this->inputs[0].push_back({r, r}); + } + } + void get_dabit_no_count(T& a, typename T::bit_type& b) { auto bit = G.get_bit(); diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 9e650a07d..853782459 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -26,6 +26,8 @@ class FakeProtocol : public ProtocolBase vector trunc_stats; + map cisc_stats; + public: Player& P; @@ -47,6 +49,10 @@ class FakeProtocol : public ProtocolBase } if (expected != 0) cerr << "Expected truncation failures: " << expected << endl; + for (auto& x : cisc_stats) + { + cerr << x.second << " " << x.first << endl; + } } template @@ -189,9 +195,22 @@ class FakeProtocol : public ProtocolBase } void cisc(SubProcessor& processor, const Instruction& instruction) + { + cisc(processor, instruction, T::characteristic_two); + } + + template + void cisc(SubProcessor&, const Instruction&, true_type) + { + throw not_implemented(); + } + + template + void cisc(SubProcessor& processor, const Instruction& instruction, false_type) { int r0 = instruction.get_r(0); string tag((char*)&r0, 4); + cisc_stats[tag.c_str()]++; auto& args = instruction.get_start(); if (tag == string("LTZ\0", 4)) { @@ -225,6 +244,56 @@ class FakeProtocol : public ProtocolBase } } } + else if (tag == "FPDi") + { + for (size_t i = 0; i < args.size(); i += args[i]) + { + assert(i + args[i] <= args.size()); + int f = args.at(i + 6); + for (int j = 0; j < args[i + 1]; j++) + { + auto& res = processor.get_S()[args[i + 2] + j]; + mpf_class a[2]; + for (int k = 0; k < 2; k++) + a[k] = bigint(typename T::clear( + processor.get_S()[args[i + 3 + k] + j])); + if (a[1] != 0) + res = bigint(a[0] / a[1] * exp2(f)); + else + res = 0; + } + } + } + else if (tag == "exp2") + { + for (size_t i = 0; i < args.size(); i += args[i]) + { + assert(i + args[i] <= args.size()); + int f = args.at(i + 5); + for (int j = 0; j < args[i + 1]; j++) + { + auto& res = processor.get_S()[args[i + 2] + j]; + auto a = bigint(typename T::clear( + processor.get_S()[args[i + 3] + j])); + res = bigint(round(exp2(mpf_class(a).get_d() / exp2(f) + f))); + } + } + } + else if (tag == "log2") + { + for (size_t i = 0; i < args.size(); i += args[i]) + { + assert(i + args[i] <= args.size()); + int f = args.at(i + 5); + for (int j = 0; j < args[i + 1]; j++) + { + auto& res = processor.get_S()[args[i + 2] + j]; + auto a = bigint(typename T::clear( + processor.get_S()[args[i + 3] + j])); + res = bigint(round((log2(mpf_class(a).get_d()) - f) * exp2(f))); + } + } + } else throw runtime_error("unknown CISC instruction: " + tag); } diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index bed380ac8..1e8021467 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -9,6 +9,9 @@ #include "SPDZ.h" #include "HemiMatrixPrep.h" +/** + * Matrix multiplication optimized with semi-homomorphic encryption + */ template class Hemi : public SPDZ { diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index d9bca4736..dc285c14c 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -104,6 +104,9 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, return C; } +/** + * Reduce convolution to matrix multiplication + */ template void Hemi::conv2ds(SubProcessor& processor, const Instruction& instruction) diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index 35db214f7..e48d92571 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -11,6 +11,9 @@ template class HemiPrep; +/** + * Semi-honest matrix triple generation using semi-homomorphic encryption + */ template class HemiMatrixPrep : public BufferPrep> { diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index b140e75ed..c43b43e95 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -11,6 +11,9 @@ template class HemiMatrixPrep; +/** + * Semi-honest triple generation with semi-homomorphic encryption (pairwise) + */ template class HemiPrep : public SemiHonestRingPrep { diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index 45dbb9949..d299fb18f 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -22,6 +22,7 @@ class HemiShare : public SemiShare typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; + // matrix multiplication only with prime order field typedef typename conditional, Beaver>::type Protocol; typedef HemiPrep LivePrep; diff --git a/Protocols/HighGearKeyGen.h b/Protocols/HighGearKeyGen.h index 7b9fb78e6..48789239d 100644 --- a/Protocols/HighGearKeyGen.h +++ b/Protocols/HighGearKeyGen.h @@ -33,6 +33,9 @@ class KeyGenBitFactory } }; +/** + * Somewhat homomorphic encryption key generation using MASCOT + */ template class HighGearKeyGen { diff --git a/Protocols/HighGearKeyGen.hpp b/Protocols/HighGearKeyGen.hpp index 41a452456..49fa6702b 100644 --- a/Protocols/HighGearKeyGen.hpp +++ b/Protocols/HighGearKeyGen.hpp @@ -14,6 +14,10 @@ HighGearKeyGen::HighGearKeyGen(Player& P, const FHE_Params& params) : { } +/** + * Generate maBits (authenticated random bits modulo two different primes) + * using daBits (authenticated random bits modulo a large prime and two) + */ template void HighGearKeyGen::buffer_mabits() { diff --git a/Protocols/LowGearKeyGen.h b/Protocols/LowGearKeyGen.h index 534cec856..2101ea6b1 100644 --- a/Protocols/LowGearKeyGen.h +++ b/Protocols/LowGearKeyGen.h @@ -14,6 +14,9 @@ #include "Math/gfp.h" #include "Math/gfpvar.h" +/** + * Homomorphic key component generation (modulo a prime) using MASCOT + */ template class KeyGenProtocol { @@ -55,6 +58,9 @@ class KeyGenProtocol vector& shares); }; +/** + * Semi-homomorphic key generation using MASCOT + */ template class LowGearKeyGen : public KeyGenProtocol<1, L> { diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index d0dd6f340..a59820404 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -79,6 +79,9 @@ void KeyGenProtocol::input(vector& shares, const Rq_Element& shares[j].push_back(inputter.finalize(j)); } +/** + * Binomial secret generation from random bits + */ template template void KeyGenProtocol::binomial(vector_type& shares, T& prep) diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index cf36c98c3..571f391ef 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -23,6 +23,9 @@ using namespace std; #define POPEN_MAX 1000000 +/** + * Sum and broadcast values via a tree of players + */ template class TreeSum { @@ -90,6 +93,9 @@ class Tree_MAC_Check : public TreeSum, public MAC_Check_B void set_random_element(const U& random_element) { (void) random_element; } }; +/** + * SPDZ opening protocol with MAC check (indirect communication) + */ template class MAC_Check_ : public Tree_MAC_Check { @@ -108,6 +114,9 @@ template class Spdz2kShare; template class Spdz2kPrep; template class MascotPrep; +/** + * SPDZ2k opening protocol with MAC check + */ template class MAC_Check_Z2k : public Tree_MAC_Check { @@ -141,6 +150,9 @@ template void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player, TreeSum& MC); +/** + * SPDZ opening protocol with MAC check (pairwise communication) + */ template class Direct_MAC_Check: public MAC_Check_ { @@ -257,12 +269,14 @@ void TreeSum::start(vector& values, const Player& P) int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players()); while (true) { + // summing phase int last_sum_players = sum_players; sum_players = (sum_players - 2 + opening_sum) / opening_sum; if (sum_players == 0) break; if (my_relative_num >= sum_players && my_relative_num < last_sum_players) { + // send to the player up the tree for (unsigned int i=0; i::start(vector& values, const Player& P) if (my_relative_num < sum_players) { + // if receiving, add the values timers[RECV_ADD].start(); add_openings(values, P, sum_players, last_sum_players, base_player, *this); timers[RECV_ADD].stop(); @@ -281,6 +296,7 @@ void TreeSum::start(vector& values, const Player& P) if (P.my_num() == base_player) { + // send from the root player os.reset_write_head(); for (unsigned int i=0; i::start(vector& values, const Player& P) } else if (my_relative_num * max_broadcast < P.num_players()) { + // send if there are children int sender = (base_player + my_relative_num / max_broadcast) % P.num_players(); ReceiveValues(values, P, sender); timers[BCAST].start(); @@ -316,6 +333,7 @@ void TreeSum::finish(vector& values, const Player& P) int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players()); if (my_relative_num * max_broadcast >= P.num_players()) { + // receiving at the leafs int sender = (base_player + my_relative_num / max_broadcast) % P.num_players(); ReceiveValues(values, P, sender); } diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index 10cc3b45a..db3f8dc71 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -83,12 +83,6 @@ void Tree_MAC_Check::exchange(const Player& P) popen_cnt += this->values.size(); CheckIfNeeded(P); - - /* not compatible with continuous communication - send_player++; - if (send_player==P.num_players()) - { send_player=0; } - */ } @@ -134,6 +128,7 @@ void MAC_Check_::Check(const Player& P) if (popen_cnt < 10) { + // no random combination with few values vector deltas; Bundle bundle(P); for (int i = 0; i < popen_cnt; i++) @@ -155,6 +150,7 @@ void MAC_Check_::Check(const Player& P) } else { + // check random combination octet seed[SEED_SIZE]; this->timers[SEED].start(); Create_Random_Seed(seed,P,SEED_SIZE); diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index ab0765ccb..c7d477ad4 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -12,6 +12,9 @@ using namespace std; #include "Networking/Player.h" #include "Tools/PointerVector.h" +/** + * Abstract base class for opening protocols + */ template class MAC_Check_Base { @@ -29,24 +32,32 @@ class MAC_Check_Base alphai(mac_key), values_opened(0) {} virtual ~MAC_Check_Base() {} + /// Run checking protocol virtual void Check(const Player& P) { (void)P; } int number() const { return values_opened; } + /// Get MAC key const typename T::mac_key_type::Scalar& get_alphai() const { return alphai; } virtual void POpen_Begin(vector& values,const vector& S,const Player& P); virtual void POpen_End(vector& values,const vector& S,const Player& P); + /// Open values in ``S`` and store results in ``values`` virtual void POpen(vector& values,const vector& S,const Player& P); typename T::open_type POpen(const T& secret, const Player& P); - // alternative name to avoid conflict + /// Open single value typename T::open_type open(const T& secret, const Player& P) { return POpen(secret, P); } + /// Initialize opening round virtual void init_open(const Player& P, int n = 0); + /// Add value to be opened virtual void prepare_open(const T& secret); + /// Run opening protocol virtual void exchange(const Player& P) = 0; + /// Get next opened value virtual typename T::open_type finalize_open(); + /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); virtual const Player& get_check_player(const Player& P) const { return P; } diff --git a/Protocols/MalRepRingPrep.h b/Protocols/MalRepRingPrep.h index 2ccb32211..ea857a5a4 100644 --- a/Protocols/MalRepRingPrep.h +++ b/Protocols/MalRepRingPrep.h @@ -8,6 +8,10 @@ #include "Protocols/ReplicatedPrep.h" +/** + * Generate random triples with malicious security modulo a power two, + * either via larger modulo or shuffling + */ template class MalRepRingPrep : public virtual BufferPrep { @@ -27,6 +31,9 @@ class MalRepRingPrep : public virtual BufferPrep void buffer_inputs(int player); }; +/** + * Generate random bits from squares modulo a power of two + */ template class RingOnlyBitsFromSquaresPrep : public virtual BufferPrep { diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index 2c9d9321c..87deaaa3b 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -30,6 +30,9 @@ class MaliciousRepMC : public ReplicatedMC } }; +/** + * 3-party replicated opening with checking via hash + */ template class HashMaliciousRepMC : public MaliciousRepMC { diff --git a/Protocols/MaliciousRepPrep.h b/Protocols/MaliciousRepPrep.h index 3e2c8d1c6..be005182a 100644 --- a/Protocols/MaliciousRepPrep.h +++ b/Protocols/MaliciousRepPrep.h @@ -19,6 +19,9 @@ template class PostSacriRepRingShare; template void sacrifice(const vector>& check_triples, Player& P); +/** + * Random bit generation from semi-honest protocol with sacrifice against square + */ template class MaliciousBitOnlyRepPrep : public virtual BufferPrep { @@ -39,6 +42,9 @@ class MaliciousBitOnlyRepPrep : public virtual BufferPrep void init_honest(Player& P); }; +/** + * Random triple/square from semi-honest protocol with sacrifice + */ template class MaliciousRepPrep : public MaliciousBitOnlyRepPrep { diff --git a/Protocols/MaliciousShamirMC.h b/Protocols/MaliciousShamirMC.h index a6b59fae2..a72c36b03 100644 --- a/Protocols/MaliciousShamirMC.h +++ b/Protocols/MaliciousShamirMC.h @@ -8,6 +8,9 @@ #include "ShamirMC.h" +/** + * Shamir share opening with correctness check + */ template class MaliciousShamirMC : public ShamirMC { diff --git a/Protocols/MamaPrep.h b/Protocols/MamaPrep.h index 9c142edc7..6a6bd9634 100644 --- a/Protocols/MamaPrep.h +++ b/Protocols/MamaPrep.h @@ -8,6 +8,9 @@ #include "MascotPrep.h" +/** + * MASCOT triple generation with multiple MACs + */ template class MamaPrep : public OTPrep, public MaliciousRingPrep { diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 4737f39e2..734453d31 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -25,6 +25,9 @@ class OTPrep : public virtual BitPrep NamedCommStats comm_stats(); }; +/** + * MASCOT input tuple generation + */ template class MascotInputPrep : public OTPrep { @@ -38,6 +41,9 @@ class MascotInputPrep : public OTPrep } }; +/** + * MASCOT triple generation + */ template class MascotTriplePrep : public MascotInputPrep { @@ -51,6 +57,9 @@ class MascotTriplePrep : public MascotInputPrep void buffer_triples(); }; +/** + * MASCOT random bit generation + */ template class MascotDabitOnlyPrep : public virtual MaliciousDabitOnlyPrep, public virtual MascotTriplePrep @@ -75,6 +84,9 @@ class MascotDabitOnlyPrep : public virtual MaliciousDabitOnlyPrep, virtual void buffer_bits(); }; +/** + * MASCOT preprocessing with edaBits + */ template class MascotPrep : public virtual MaliciousRingPrep, public virtual MascotDabitOnlyPrep diff --git a/Protocols/NoShare.h b/Protocols/NoShare.h index e44006400..d966f5867 100644 --- a/Protocols/NoShare.h +++ b/Protocols/NoShare.h @@ -10,6 +10,7 @@ #include "Math/bigint.h" #include "Math/gfp.h" #include "GC/NoShare.h" +#include "BMR/Register.h" #include "NoLivePrep.h" #include "NoProtocol.h" diff --git a/Protocols/PostSacrifice.h b/Protocols/PostSacrifice.h index c9ed65b6f..73ec766e4 100644 --- a/Protocols/PostSacrifice.h +++ b/Protocols/PostSacrifice.h @@ -8,6 +8,9 @@ #include "Protocols/Replicated.h" +/** + * Protocol with optimistic multiplication and postponed sacrifice + */ template class PostSacrifice : public ProtocolBase { diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index f2dbaf7a6..aa0fc7bce 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -8,6 +8,9 @@ #include "Replicated.h" +/** + * Four-party protocol with malicious security via replication + */ template class Rep4 : public ProtocolBase { diff --git a/Protocols/RepRingOnlyEdabitPrep.h b/Protocols/RepRingOnlyEdabitPrep.h index cd324b5a8..205e7d342 100644 --- a/Protocols/RepRingOnlyEdabitPrep.h +++ b/Protocols/RepRingOnlyEdabitPrep.h @@ -8,6 +8,9 @@ #include "ReplicatedPrep.h" +/** + * edaBit generation for replicated secret sharing modulo a power of two + */ template class RepRingOnlyEdabitPrep : public virtual BufferPrep { diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 74d0f0f3b..3de9bfabc 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -26,6 +26,9 @@ template class MAC_Check_Base; template class Preprocessing; class Instruction; +/** + * Base class for replicated three-party protocols + */ class ReplicatedBase { public: @@ -41,6 +44,9 @@ class ReplicatedBase int get_n_relevant_players() { return P.num_players() - 1; } }; +/** + * Abstract base class for multiplication protocols + */ template class ProtocolBase { @@ -67,17 +73,27 @@ class ProtocolBase void multiply(vector& products, vector>& multiplicands, int begin, int end, SubProcessor& proc); + /// Single multiplication T mul(const T& x, const T& y); + /// Initialize multiplication round virtual void init_mul(SubProcessor* proc) = 0; + /// Schedule multiplication of operand pair virtual typename T::clear prepare_mul(const T& x, const T& y, int n = -1) = 0; + /// Run multiplication protocol virtual void exchange() = 0; + /// Get next multiplication result virtual T finalize_mul(int n = -1) = 0; + /// Store next multiplication result in ``res`` virtual void finalize_mult(T& res, int n = -1); + /// Initialize dot product round void init_dotprod(SubProcessor* proc) { init_mul(proc); } + /// Add operand pair to current dot product void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); } + /// Finish dot product void next_dotprod() {} + /// Get next dot product result T finalize_dotprod(int length); virtual T get_random(); @@ -106,6 +122,9 @@ class ProtocolBase { throw runtime_error("CISC instructions not implemented"); } }; +/** + * Semi-honest replicated three-party protocol + */ template class Replicated : public ReplicatedBase, public ProtocolBase { diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 29dbbf142..7d62838a3 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -10,6 +10,9 @@ #include "Processor/Processor.h" #include "Replicated.h" +/** + * Base class for input protocols without preprocessing + */ template class PrepLessInput : public InputBase { @@ -33,6 +36,9 @@ class PrepLessInput : public InputBase T finalize_mine(); }; +/** + * Replicated three-party input protocol + */ template class ReplicatedInput : public PrepLessInput { diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index cfcd749ce..bb6f36a20 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -8,6 +8,9 @@ #include "MAC_Check_Base.h" +/** + * Replicated semi-honest three-party opening protocol + */ template class ReplicatedMC : public MAC_Check_Base { diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index a8c8266af..8c3ed3f13 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -25,6 +25,9 @@ namespace GC template class ShareThread; } +/** + * Abstract base class for live preprocessing + */ template class BufferPrep : public Preprocessing { @@ -78,8 +81,11 @@ class BufferPrep : public Preprocessing int buffer_size; + /// Key-independent setup if necessary (cryptosystem parameters) static void basic_setup(Player& P) { (void) P; } + /// Generate keys if necessary static void setup(Player& P, typename T::mac_key_type alphai) { (void) P, (void) alphai; } + /// Free memory of global cryptosystem parameters static void teardown() {} static void edabit_sacrifice_buckets(vector>&, size_t, bool, int, @@ -102,6 +108,7 @@ class BufferPrep : public Preprocessing virtual void get_dabit_no_count(T& a, typename T::bit_type& b); + /// Get fresh random value virtual T get_random(); void push_triples(const vector>& triples) @@ -118,6 +125,9 @@ class BufferPrep : public Preprocessing void set_proc(SubProcessor* proc) { this->proc = proc; } }; +/** + * Generic preprocessing protocols + */ template class BitPrep : public virtual BufferPrep { @@ -135,16 +145,23 @@ class BitPrep : public virtual BufferPrep void set_protocol(typename T::Protocol& protocol); + /// Generate squares from triples void buffer_squares(); + /// Generate random bits from inputs without semi-honest security void buffer_bits_without_check(); }; +/** + * Generate (e)daBit protocols + */ template class RingPrep : public virtual BitPrep { typedef typename T::bit_type::part_type BT; + SubProcessor* bit_part_proc; + protected: void buffer_dabits_without_check(vector>& dabits, int buffer_size = -1, ThreadQueues* queues = 0); @@ -169,10 +186,11 @@ class RingPrep : public virtual BitPrep public: RingPrep(SubProcessor* proc, DataPositions& usage); - virtual ~RingPrep() {} + virtual ~RingPrep(); vector& get_bits() { return this->bits; } + /// Generate strict edabits from loose ones template void sanitize(vector>& edabits, int n_bits, int player = -1, ThreadQueues* queues = 0); @@ -180,6 +198,7 @@ class RingPrep : public virtual BitPrep void sanitize(vector>& edabits, int n_bits, int player, int begin, int end); + /// Generic daBit generation with semi-honest security void buffer_dabits_without_check(vector>& dabits, size_t begin, size_t end); template @@ -187,6 +206,7 @@ class RingPrep : public virtual BitPrep size_t begin, size_t end, Preprocessing& bit_prep); + /// Generic edaBit generation with semi-honest security template void buffer_edabits_without_check(int n_bits, vector& sums, vector>& bits, int begin, @@ -198,6 +218,9 @@ class RingPrep : public virtual BitPrep int begin, int end); }; +/** + * Semi-honest *bit preprocessing + */ template class SemiHonestRingPrep : public virtual RingPrep { @@ -229,6 +252,9 @@ class SemiHonestRingPrep : public virtual RingPrep { this->buffer_sedabits_from_edabits(n_bits); } }; +/** + * daBit preprocessing with malicious security + */ template class MaliciousDabitOnlyPrep : public virtual RingPrep { @@ -250,6 +276,9 @@ class MaliciousDabitOnlyPrep : public virtual RingPrep virtual void buffer_dabits(ThreadQueues* queues); }; +/** + * Random bit and edaBit preprocessing with malicious security + */ template class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep { @@ -301,6 +330,9 @@ class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep virtual void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues); }; +/** + * Semi-honest preprocessing with honest majority (no (e)daBits) + */ template class ReplicatedRingPrep : public virtual BitPrep { @@ -319,6 +351,9 @@ class ReplicatedRingPrep : public virtual BitPrep virtual void buffer_bits() { this->buffer_bits_without_check(); } }; +/** + * Semi-honest preprocessing with honest majority (including (e)daBits) + */ template class ReplicatedPrep : public virtual ReplicatedRingPrep, public virtual SemiHonestRingPrep diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 27c96c19b..2b8aa1604 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -48,26 +48,33 @@ BufferPrep::BufferPrep(DataPositions& usage) : template BufferPrep::~BufferPrep() { -#ifdef VERBOSE string type_string = T::type_string(); + +#ifdef VERBOSE if (n_bit_rounds > 0) cerr << n_bit_rounds << " rounds of random " << type_string << " bit generation" << endl; +#endif - this->print_left("triples", triples.size() * T::default_length, - type_string); + if (OnlineOptions::singleton.verbose) + { + this->print_left("triples", triples.size() * T::default_length, + type_string); #define X(KIND) \ this->print_left(#KIND, KIND.size(), type_string); - X(squares) X(inverses) X(bits) X(dabits) + X(squares) + X(inverses) + X(bits) + X(dabits) #undef X - for (auto& x : this->edabits) - { - this->print_left_edabits(x.second.size(), x.second[0].size(), - x.first.first, x.first.second); + for (auto& x : this->edabits) + { + this->print_left_edabits(x.second.size(), x.second[0].size(), + x.first.first, x.first.second); + } } -#endif } template @@ -79,8 +86,15 @@ BitPrep::BitPrep(SubProcessor* proc, DataPositions& usage) : template RingPrep::RingPrep(SubProcessor* proc, DataPositions& usage) : - BufferPrep(usage), BitPrep(proc, usage) + BufferPrep(usage), BitPrep(proc, usage), bit_part_proc(0) +{ +} + +template +RingPrep::~RingPrep() { + if (bit_part_proc) + delete bit_part_proc; } template @@ -708,8 +722,10 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, assert(this->protocol != 0); assert(proc != 0); auto &party = GC::ShareThread::s(); - SubProcessor bit_proc(party.MC->get_part_MC(), proc->bit_prep, - proc->P); + if (bit_part_proc == 0) + bit_part_proc = new SubProcessor(party.MC->get_part_MC(), + proc->bit_prep, proc->P); + auto& bit_proc = *bit_part_proc; int n_relevant = this->protocol->get_n_relevant_players(); vector> player_ints(n_relevant, vector(buffer_size)); vector>> parts(n_relevant, @@ -1070,8 +1086,7 @@ void Preprocessing::get_edabits(bool strict, size_t size, T* a, { if (i % unit == 0) Sb[regs[j] + i / unit] = {}; - Sb[regs[j] + i / unit] ^= - (typename T::bit_type(eb.second[j]) << (i % unit)); + Sb[regs[j] + i / unit].xor_bit(i % unit, eb.second[j]); } } } diff --git a/Protocols/ReplicatedPrep2k.h b/Protocols/ReplicatedPrep2k.h index 0c2e9510e..da35865e4 100644 --- a/Protocols/ReplicatedPrep2k.h +++ b/Protocols/ReplicatedPrep2k.h @@ -8,6 +8,9 @@ #include "ReplicatedPrep.h" +/** + * Preprocessing for three-party replicated secret sharing modulo a power of two + */ template class ReplicatedPrep2k : public virtual SemiHonestRingPrep, public virtual ReplicatedRingPrep diff --git a/Protocols/RingOnlyPrep.h b/Protocols/RingOnlyPrep.h index 78fd3415a..2d2f1928f 100644 --- a/Protocols/RingOnlyPrep.h +++ b/Protocols/RingOnlyPrep.h @@ -8,6 +8,9 @@ #include "ReplicatedPrep.h" +/** + * Semi-honest daBit generation for computation modulo a power of two + */ template class RingOnlyPrep : public virtual RingPrep { diff --git a/Protocols/SPDZ.h b/Protocols/SPDZ.h index bc7b5f057..fb2888c05 100644 --- a/Protocols/SPDZ.h +++ b/Protocols/SPDZ.h @@ -15,6 +15,9 @@ template class SubProcessor; template class Share; class Player; +/** + * SPDZ protocol + */ template class SPDZ : public Beaver { diff --git a/Protocols/Semi2k.h b/Protocols/Semi2k.h index 646c955e3..69cf63aad 100644 --- a/Protocols/Semi2k.h +++ b/Protocols/Semi2k.h @@ -9,6 +9,9 @@ #include "SPDZ.h" #include "Processor/TruncPrTuple.h" +/** + * Dishonest-majority protocol for computation modulo a power of two + */ template class Semi2k : public SPDZ { diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 0cc348d95..87a1e08e5 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -10,6 +10,9 @@ template class SemiMC; +/** + * Additive secret sharing input protocol + */ template class SemiInput : public IndividualInput { diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index 67f3b284b..fe4d9db6c 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -9,6 +9,9 @@ #include "MAC_Check.h" #include "Tools/Bundle.h" +/** + * Additive secret sharing opening protocol (indirect communication) + */ template class SemiMC : public TreeSum, public MAC_Check_Base { @@ -26,6 +29,9 @@ class SemiMC : public TreeSum, public MAC_Check_Base SemiMC& get_part_MC() { return *this; } }; +/** + * Additive secret sharing opening protocol (direct communication) + */ template class DirectSemiMC : public SemiMC { diff --git a/Protocols/SemiPrep.h b/Protocols/SemiPrep.h index 7148c8a9b..12e17203f 100644 --- a/Protocols/SemiPrep.h +++ b/Protocols/SemiPrep.h @@ -8,6 +8,9 @@ #include "MascotPrep.h" +/** + * Semi-honest triple generation based on oblivious transfer + */ template class SemiPrep : public virtual OTPrep, public virtual SemiHonestRingPrep { diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h index 33ce580c6..50311c594 100644 --- a/Protocols/SemiPrep2k.h +++ b/Protocols/SemiPrep2k.h @@ -9,6 +9,9 @@ #include "SemiPrep.h" #include "RepRingOnlyEdabitPrep.h" +/** + * Preprocessing for additive secret sharing modulo a power of two + */ template class SemiPrep2k : public SemiPrep, public RepRingOnlyEdabitPrep { diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index e9336c753..3d2bf469b 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -19,6 +19,9 @@ template class IndirectShamirMC; class Player; +/** + * Shamir secret sharing-based protocol with resharing + */ template class Shamir : public ProtocolBase { diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 100107469..d387f3b47 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -218,21 +218,21 @@ void Shamir::get_hyper(vector >& hyper, octetStream os; string filename = hyper_filename(t, n); ifstream in(filename); -#ifdef VERBOSE +#ifdef VERBOSE_HYPER cerr << "Trying to load hyper-invertable matrix from " << filename << endl; #endif os.input(in); os.get(hyper); if (int(hyper.size()) != n - t) throw exception(); -#ifdef VERBOSE +#ifdef VERBOSE_HYPER cerr << "Loaded hyper-invertable matrix from " << filename << endl; #endif return; } catch (...) { -#ifdef VERBOSE +#ifdef VERBOSE_HYPER cerr << "Failed to load hyper-invertable" << endl; #endif } diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 5958efc61..023467077 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -11,6 +11,10 @@ #include "ReplicatedInput.h" #include "Machines/ShamirMachine.h" +/** + * Base class for input protocols where the inputting player sends a share + * to every other player + */ template class IndividualInput : public PrepLessInput { @@ -36,6 +40,9 @@ class IndividualInput : public PrepLessInput void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; +/** + * Shamir secret sharing input protocol + */ template class ShamirInput : public IndividualInput { diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index ccd370f21..8f76d6a79 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -11,6 +11,9 @@ #include "Machines/ShamirMachine.h" #include "Tools/Bundle.h" +/** + * Shamir secret sharing opening protocol (indirect communication) + */ template class IndirectShamirMC : public MAC_Check_Base { @@ -24,6 +27,9 @@ class IndirectShamirMC : public MAC_Check_Base virtual void exchange(const Player& P); }; +/** + * Shamir secret sharing opening protocol (direct communication) + */ template class ShamirMC : public IndirectShamirMC { diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index faaf86875..b8ffd0aaf 100644 --- a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -18,6 +18,9 @@ class Player; template class LimitedPrep; +/** + * Base class for shuffle sacrificing + */ class ShuffleSacrifice { protected: diff --git a/Protocols/SohoPrep.h b/Protocols/SohoPrep.h index 07bc7b9c4..5e28381be 100644 --- a/Protocols/SohoPrep.h +++ b/Protocols/SohoPrep.h @@ -6,6 +6,9 @@ #ifndef PROTOCOLS_SOHOPREP_H_ #define PROTOCOLS_SOHOPREP_H_ +/** + * Semi-honest preprocessing with somewhat homomorphic encryption + */ template class SohoPrep : public SemiHonestRingPrep { diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index ab0cc33ea..33883c66f 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -13,6 +13,9 @@ template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); +/** + * SPDZ2k preprocessing + */ template class Spdz2kPrep : public virtual MaliciousRingPrep, public virtual MascotTriplePrep, diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h index cb049fceb..afbf2c850 100644 --- a/Protocols/SpdzWise.h +++ b/Protocols/SpdzWise.h @@ -10,6 +10,9 @@ template class SpdzWiseInput; +/** + * Honest-majority protocol with MAC check + */ template class SpdzWise : public ProtocolBase { diff --git a/Protocols/SpdzWiseInput.h b/Protocols/SpdzWiseInput.h index 458fe02a1..e9597527d 100644 --- a/Protocols/SpdzWiseInput.h +++ b/Protocols/SpdzWiseInput.h @@ -8,6 +8,9 @@ #include "ReplicatedInput.h" +/** + * Honest-majority input protocol with MAC + */ template class SpdzWiseInput : public InputBase { diff --git a/Protocols/SpdzWisePrep.h b/Protocols/SpdzWisePrep.h index 35be8cb94..6b4df251f 100644 --- a/Protocols/SpdzWisePrep.h +++ b/Protocols/SpdzWisePrep.h @@ -10,6 +10,9 @@ template class MaliciousShamirShare; +/** + * Preprocessing for honest-majority protocol with MAC + */ template class SpdzWisePrep : public MaliciousRingPrep { diff --git a/Protocols/SpdzWiseRing.h b/Protocols/SpdzWiseRing.h index 9cc6c12c4..c1c04c192 100644 --- a/Protocols/SpdzWiseRing.h +++ b/Protocols/SpdzWiseRing.h @@ -10,6 +10,9 @@ #include "PostSacrifice.h" #include "PostSacriRepRingShare.h" +/** + * Three-party replicated secret sharing protocol with MAC modulo a power of two + */ template class SpdzWiseRing : public SpdzWise { diff --git a/Protocols/SpdzWiseRingPrep.h b/Protocols/SpdzWiseRingPrep.h index c201c7a6b..4a16b92ee 100644 --- a/Protocols/SpdzWiseRingPrep.h +++ b/Protocols/SpdzWiseRingPrep.h @@ -9,6 +9,10 @@ #include "SpdzWisePrep.h" #include "RepRingOnlyEdabitPrep.h" +/** + * Preprocessing for three-party replicated secret sharing protocol with MAC + * modulo a power of two + */ template class SpdzWiseRingPrep : public virtual SpdzWisePrep, public virtual RepRingOnlyEdabitPrep diff --git a/Protocols/SpdzWiseShare.h b/Protocols/SpdzWiseShare.h index 101965cad..55a19dedc 100644 --- a/Protocols/SpdzWiseShare.h +++ b/Protocols/SpdzWiseShare.h @@ -80,6 +80,7 @@ class SpdzWiseShare : public Share_ { } + void pack(octetStream& os, bool full = true) const; void pack(octetStream& os, open_type factor) const; }; diff --git a/Protocols/SpdzWiseShare.hpp b/Protocols/SpdzWiseShare.hpp index 038556936..6401c083a 100644 --- a/Protocols/SpdzWiseShare.hpp +++ b/Protocols/SpdzWiseShare.hpp @@ -40,6 +40,12 @@ void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& } } +template +void SpdzWiseShare::pack(octetStream& os, bool full) const +{ + super::pack(os, full); +} + template void SpdzWiseShare::pack(octetStream& os, open_type factor) const { diff --git a/Protocols/dabit.h b/Protocols/dabit.h index c6a61fe00..9f7741f0a 100644 --- a/Protocols/dabit.h +++ b/Protocols/dabit.h @@ -27,6 +27,11 @@ class dabit : public pair return T::type_string(); } + static void specification(octetStream& os) + { + T::specification(os); + } + dabit() { } diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 9add273fa..0c209869b 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -6,6 +6,8 @@ using namespace std; #include "Networking/Player.h" +#include "Processor/Data_Files.h" +#include "Math/Setup.h" template void check_share(vector& Sa, typename T::clear& value, @@ -43,15 +45,27 @@ class Files int N; typename T::mac_type key; PRNG G; - Files(int N, const typename T::mac_type& key, const string& prefix) : N(N), key(key) + Files(int N, const typename T::mac_type& key, const string& prep_data_prefix, + Dtype type, int thread_num = -1) : + Files(N, key, + get_prep_sub_dir(prep_data_prefix, N) + + DataPositions::dtype_names[type] + "-" + T::type_short(), + thread_num) + { + } + Files(int N, const typename T::mac_type& key, const string& prefix, + int thread_num = -1) : + N(N), key(key) { outf = new ofstream[N]; for (int i=0; i().output(outf[i]); if (outf[i].fail()) throw file_error(filename.str().c_str()); } diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 89b029b71..951cbfe74 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -485,7 +485,6 @@ inline void check_files(ofstream* outf, int N) /* N = Number players * ntrip = Number triples needed - * str = "2" or "p" */ template void make_mult_triples(const typename T::mac_type& key, int N, int ntrip, @@ -496,44 +495,25 @@ void make_mult_triples(const typename T::mac_type& key, int N, int ntrip, PRNG G; G.ReSeed(); - ofstream* outf=new ofstream[N]; + Files files(N, key, prep_data_prefix, DATA_TRIPLE, thread_num); typename T::clear a,b,c; - vector Sa(N),Sb(N),Sc(N); /* Generate Triples */ - for (int i=0; i(prep_data_prefix, N), DATA_TRIPLE, - T::type_short(), i, thread_num); - cout << "Opening " << filename << endl; - outf[i].open(filename,ios::out | ios::binary); - if (outf[i].fail()) { throw file_error(filename); } - } for (int i=0; i void make_inverse(const typename T::mac_type& key, int N, int ntrip, bool zero, @@ -542,17 +522,8 @@ void make_inverse(const typename T::mac_type& key, int N, int ntrip, bool zero, PRNG G; G.ReSeed(); - ofstream* outf=new ofstream[N]; + Files files(N, key, prep_data_prefix, DATA_INVERSE); typename T::clear a,b; - vector Sa(N),Sb(N); - /* Generate Triples */ - for (int i=0; i(prep_data_prefix, N) << "Inverses-" << T::type_short() << "-P" << i; - cout << "Opening " << filename.str() << endl; - outf[i].open(filename.str().c_str(),ios::out | ios::binary); - if (outf[i].fail()) { throw file_error(filename.str().c_str()); } - } for (int i=0; i&1 | tee logs/$log & true -done - -wait || exit 1 +run_player yao-party.x $* || exit 1 diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 33dd3d6be..75cb8b6ed 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -4,6 +4,7 @@ */ #include "Tools/Buffer.h" +#include "Processor/BaseMachine.h" #include @@ -22,15 +23,30 @@ void BufferBase::setup(ifstream* f, int length, const string& filename, void BufferBase::seekg(int pos) { +#ifdef DEBUG_BUFFER + if (pos != 0) + printf("seek %d %s thread %d\n", pos, filename.c_str(), + BaseMachine::thread_num); +#endif if (not file) - file = open(); - file->seekg(pos * tuple_length); + { + if (pos == 0) + return; + else + file = open(); + } + + file->seekg(header_length + pos * tuple_length); if (file->eof() || file->fail()) { // let it go in case we don't need it anyway if (pos != 0) try_rewind(); } +#ifdef DEBUG_BUFFER + printf("seek %d %d thread %d\n", pos, int(file->tellg()), + BaseMachine::thread_num); +#endif next = BUFFER_SIZE; } @@ -40,10 +56,10 @@ void BufferBase::try_rewind() string type; if (field_type.size() and data_type.size()) type = (string)" of " + field_type + " " + data_type; - throw not_enough_to_buffer(type); + throw not_enough_to_buffer(type, filename); #endif file->clear(); // unset EOF flag - file->seekg(0); + file->seekg(header_length); if (file->peek() == ifstream::traits_type::eof()) throw runtime_error("empty file: " + filename); if (!rewind) @@ -54,18 +70,26 @@ void BufferBase::try_rewind() void BufferBase::prune() { - if (file and not file->good()) + if (file and (not file->good() or file->peek() == EOF)) purge(); - else if (file and file->tellg() != 0) + else if (file and file->tellg() != header_length) { #ifdef VERBOSE cerr << "Pruning " << filename << endl; #endif string tmp_name = filename + ".new"; ofstream tmp(tmp_name.c_str()); + size_t start = file->tellg(); + char buf[header_length]; + file->seekg(0); + file->read(buf, header_length); + tmp.write(buf, header_length); + file->seekg(start); tmp << file->rdbuf(); if (tmp.fail()) - throw runtime_error("problem writing to " + tmp_name); + throw runtime_error( + "problem writing to " + tmp_name + " from " + + to_string(start) + " of " + filename); tmp.close(); file->close(); rename(tmp_name.c_str(), filename.c_str()); diff --git a/Tools/Buffer.h b/Tools/Buffer.h index 84fbdaca9..a95dee0d9 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -13,6 +13,7 @@ using namespace std; #include "Math/field_types.h" #include "Tools/time-func.h" +#include "Tools/octetStream.h" #ifndef BUFFER_SIZE #define BUFFER_SIZE 101 @@ -30,12 +31,13 @@ class BufferBase Timer timer; int tuple_length; string filename; + int header_length; public: bool eof; BufferBase() : file(0), next(BUFFER_SIZE), - tuple_length(-1), eof(false) {} + tuple_length(-1), header_length(0), eof(false) {} ~BufferBase() {} virtual ifstream* open() = 0; void setup(ifstream* f, int length, const string& filename, @@ -63,6 +65,31 @@ class Buffer : public BufferBase void fill_buffer(); }; +template +octetStream file_signature() +{ + octetStream res(T::type_string()); + T::specification(res); + return res; +} + +template +octetStream check_file_signature(ifstream& file, const string& filename) +{ + octetStream file_spec; + try + { + file_spec.input(file); + } + catch (bad_alloc&) + { + throw signature_mismatch(filename); + } + if (file_signature() != file_spec) + throw signature_mismatch(filename); + return file_spec; +} + template class BufferOwner : public Buffer { @@ -88,6 +115,12 @@ class BufferOwner : public Buffer ifstream* open() { file = new ifstream(this->filename, ios::in | ios::binary); + if (file->good()) + { + auto file_spec = check_file_signature(*file, this->filename); + this->header_length = file_spec.get_length() + + sizeof(file_spec.get_length()); + } return file; } @@ -159,6 +192,9 @@ inline void Buffer::read(char* read_buffer) { file->read(read_buffer + n_read, size_in_bytes - n_read); n_read += file->gcount(); +#ifdef DEBUG_BUFFER + fprintf(stderr, "read %d\n", n_read); +#endif if (file->eof()) { try_rewind(); @@ -180,6 +216,9 @@ inline void Buffer::read(char* read_buffer) template inline void Buffer::input(U& a) { +#ifdef DEBUG_BUFFER + fprintf(stderr, "next is %d\n", next); +#endif if (next == BUFFER_SIZE) { fill_buffer(); diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index ee6722e28..96f69b0c5 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -51,11 +51,35 @@ invalid_opcode::invalid_opcode(int opcode) : } input_error::input_error(const char* name, const string& filename, - istream& input_file) + istream& input_file, size_t input_counter) { input_file.clear(); string token; input_file >> token; msg += string() + "cannot read " + name + " from " + filename - + ", problem with '" + token + "'"; + + ", problem with '" + token + "' after " + + to_string(input_counter); +} + +signature_mismatch::signature_mismatch(const string& filename) : + runtime_error("Signature in " + filename + " doesn't match protocol. " + "Re-run preprocessing") +{ +} + +insufficient_memory::insufficient_memory(size_t size, const string& type) : + runtime_error( + "program requires too much " + type + " memory: " + + to_string(size)) +{ +} + +not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& filename) : + runtime_error( + "Not enough data available for buffer" + + (filename.empty() ? "" : (" in " + filename)) + ". " + "Maybe insufficient preprocessing" + type + + ".\nFor benchmarking, you can activate reusing data by " + "adding -DINSECURE to the compiler options.") +{ } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index fd9820c26..18406cf6c 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -187,14 +187,7 @@ class how_would_that_work : public exception {}; class not_enough_to_buffer : public runtime_error { public: - not_enough_to_buffer(string type) : - runtime_error( - "Not enough data available for buffer. " - "Maybe insufficient preprocessing" + type - + ".\nFor benchmarking, you can activate reusing data by " - "adding -DINSECURE to the compiler options.") - { - } + not_enough_to_buffer(const string& type, const string& filename); }; class needs_cleaning : public exception {}; @@ -265,7 +258,7 @@ class input_error : public exception public: input_error(const char* name, const string& filename, - istream& input_file); + istream& input_file, size_t input_counter); const char* what() const throw() { @@ -273,4 +266,16 @@ class input_error : public exception } }; +class signature_mismatch : public runtime_error +{ +public: + signature_mismatch(const string& filename); +}; + +class insufficient_memory : public runtime_error +{ +public: + insufficient_memory(size_t size, const string& type); +}; + #endif diff --git a/Tools/aes-arm.h b/Tools/aes-arm.h index 0eacbe00a..33f24e883 100644 --- a/Tools/aes-arm.h +++ b/Tools/aes-arm.h @@ -246,22 +246,22 @@ FORCE_INLINE __m128i _mm_aesenclast_si128(__m128i a, __m128i RoundKey) { /* FIXME: optimized for NEON */ uint8_t v[4][4] = { - [0] = {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 0)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 5)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 10)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 15)]}, - [1] = {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 4)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 9)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 14)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 3)]}, - [2] = {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 8)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 13)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 2)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 7)]}, - [3] = {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 12)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 1)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 6)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 11)]}, + {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 0)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 5)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 10)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 15)]}, + {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 4)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 9)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 14)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 3)]}, + {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 8)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 13)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 2)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 7)]}, + {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 12)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 1)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 6)], + SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 11)]}, }; for (int i = 0; i < 16; i++) vreinterpretq_nth_u8_m128i(a, i) = diff --git a/Tools/cpu_support.h b/Tools/cpu_support.h index 755c302d9..aec7d2b3e 100644 --- a/Tools/cpu_support.h +++ b/Tools/cpu_support.h @@ -6,12 +6,19 @@ #ifndef TOOLS_CPU_SUPPORT_H_ #define TOOLS_CPU_SUPPORT_H_ +#include + inline bool check_cpu(int func, bool ecx, int feature) { +#ifdef __aarch64__ + (void) func, (void) ecx, (void) feature; + throw std::runtime_error("only for x86"); +#else int ax = func, bx, cx = 0, dx; __asm__ __volatile__ ("cpuid": "+a" (ax), "=b" (bx), "+c" (cx), "=d" (dx)); return ((ecx ? cx : bx) >> feature) & 1; +#endif } inline bool cpu_has_adx() diff --git a/Utils/Check-Offline-Z2k.cpp b/Utils/Check-Offline-Z2k.cpp index 7994b08b6..f00f6de4e 100644 --- a/Utils/Check-Offline-Z2k.cpp +++ b/Utils/Check-Offline-Z2k.cpp @@ -33,6 +33,9 @@ void check_triples_Z2k(int n_players, string type_char = "") ss << "-P" << i; inputFiles[i].open(ss.str().c_str()); cout << "Opening file " << ss.str() << endl; + octetStream tmp; + tmp.input(inputFiles[i]); + assert(tmp == file_signature()); } int j = 0; diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index 738a58916..3a3644144 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -59,13 +59,9 @@ void check_mult_triples(const typename T::mac_key_type& key,int N,vector void make_square_tuples(const typename T::mac_type& key,int N,int ntrip,const string& str,bool zero) @@ -88,34 +87,18 @@ void make_square_tuples(const typename T::mac_type& key,int N,int ntrip,const st PRNG G; G.ReSeed(); - ofstream* outf=new ofstream[N]; + Files files(N, key, prep_data_prefix, DATA_SQUARE); typename T::clear a,c; - vector Sa(N),Sc(N); /* Generate Squares */ - for (int i=0; i(prep_data_prefix, N) << "Squares-" - << T::type_short() << "-P" << i; - cout << "Opening " << filename.str() << endl; - outf[i].open(filename.str().c_str(),ios::out | ios::binary); - if (outf[i].fail()) { throw file_error(filename.str().c_str()); } - } for (int i=0; i files(N, key, prep_data_prefix, DATA_BIT, thread_num); typename T::clear a; - vector Sa(N); /* Generate Bits */ - for (int i=0; i(prep_data_prefix, N) << "Bits-" - << T::type_short() << "-P" << i - << Sub_Data_Files::get_suffix(thread_num); - cout << "Opening " << filename.str() << endl; - outf[i].open(filename.str().c_str(),ios::out | ios::binary); - if (outf[i].fail()) { throw file_error(filename.str().c_str()); } - } for (int i=0; i @@ -206,8 +174,6 @@ void FakeParams::make_edabits(const typename T::mac_type& key, int N, int ntrip, /* N = Number players * ntrip = Number inputs needed - * str = "2" or "p" - * */ template void make_inputs(const typename T::mac_type& key,int N,int ntrip,const string& str,bool zero) @@ -228,6 +194,7 @@ void make_inputs(const typename T::mac_type& key,int N,int ntrip,const string& s << T::type_short() << "-P" << i << "-" << player; cout << "Opening " << filename.str() << endl; outf[i].open(filename.str().c_str(),ios::out | ios::binary); + file_signature().output(outf[i]); if (outf[i].fail()) { throw file_error(filename.str().c_str()); } } for (int i=0; i>(keytt, nplayers, prep_data_prefix); - make_minimal>(keytt, nplayers, default_num / 64, zero); + make_minimal>(keytt, nplayers, default_num, zero); make_dabits(keyp, nplayers, default_num, zero, keytt); make_edabits(keyp, nplayers, default_num, zero, false_type(), keytt); @@ -804,6 +771,8 @@ int FakeParams::generate() { make_mult_triples>({}, nplayers, default_num, zero, prep_data_prefix); + make_bits>({}, nplayers, + default_num, zero); } generate_field(T::clear::prime_field); diff --git a/Utils/check-passive.cpp b/Utils/check-passive.cpp index dd20bdc86..86923be52 100644 --- a/Utils/check-passive.cpp +++ b/Utils/check-passive.cpp @@ -25,6 +25,9 @@ void check_triples(int n_players, string type_char = "") ss << "-P" << i; inputFiles[i].open(ss.str().c_str()); cout << "Opening file " << ss.str() << endl; + octetStream tmp, tmp2 = file_signature(); + tmp.input(inputFiles[i]); + assert(tmp == tmp2); } int j = 0; @@ -78,6 +81,7 @@ int main(int argc, char** argv) n_players = atoi(argv[1]); read_setup(get_prep_sub_dir>(PREP_DIR, n_players, 128)); gfp::init_field(gfp::pr(), false); + gf2n::init_field(); check_triples(n_players); check_triples(n_players); } diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 2379d42d8..b3240e086 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -139,7 +139,7 @@ void YaoEvalWire::my_input(T& inputter, bool value, int n_bits) assert(n_bits == 1); auto& inputs = inputter.inputs; size_t start = inputs.size(); - inputs.resize(start + 1); + inputs.resize_zero(start + 1); inputs.set_bit(start, value); } @@ -215,10 +215,32 @@ void YaoEvalWire::set(Key key, bool external) void YaoEvalWire::convcbit(Integer& dest, const GC::Clear& source, GC::Processor>&) { - auto& evaluator = YaoEvaluator::s(); dest = source; - evaluator.P->send_long(0, source.get()); - throw needs_cleaning(); + auto &evaluator = YaoEvaluator::s(); + if (not evaluator.continuous()) + { + evaluator.P->send_long(0, source.get()); + throw needs_cleaning(); + } +} + +void YaoEvalWire::reveal_inst(Processor& processor, const vector& args) +{ + processor.reveal(args); + auto &evaluator = YaoEvaluator::s(); + if (evaluator.continuous()) + { + octetStream buffer; + for (size_t j = 0; j < args.size(); j += 3) + { + int n = args[j]; + int r0 = args[j + 1]; + for (int i = 0; i < DIV_CEIL(n, GC::Clear::N_BITS); i++) + processor.C[r0 + i].pack(buffer); + } + YaoEvaluator::s().P->send_to(0, buffer); + throw needs_cleaning(); + } } template void YaoEvalWire::and_( diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index 18b1e4caf..dc5d45a91 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -59,6 +59,7 @@ class YaoEvalWire : public YaoWire static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>&); + static void reveal_inst(Processor& processor, const vector& args); void set(const Key& key); void set(Key key, bool external); diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 0126c8b15..7b1b60154 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -29,7 +29,8 @@ YaoEvaluator::YaoEvaluator(int thread_num, YaoEvalMaster& master) : void YaoEvaluator::pre_run() { - processor.out.activate(true); + if (master.opts.cmd_private_output_file.empty()) + processor.out.activate(not continuous()); if (not continuous()) receive_to_store(*P); } diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 637ce4b40..7e52602ec 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -8,6 +8,7 @@ #include "YaoGarbler.h" #include "YaoGarbleInput.h" #include "GC/ArgTuples.h" +#include "Tools/pprint.h" #include "GC/Processor.hpp" #include "GC/Secret.hpp" @@ -197,8 +198,35 @@ char YaoGarbleWire::get_output() void YaoGarbleWire::convcbit(Integer& dest, const GC::Clear& source, GC::Processor>&) { - (void) source; - auto& garbler = YaoGarbler::s(); - garbler.untaint(); - dest = garbler.P->receive_long(1); + auto &garbler = YaoGarbler::s(); + if (garbler.continuous()) + dest = source; + else + { + garbler.untaint(); + dest = garbler.P->receive_long(1); + } +} + +void YaoGarbleWire::reveal_inst(Processor& processor, const vector& args) +{ + auto &garbler = YaoGarbler::s(); + if (garbler.continuous()) + { + if (garbler.is_tainted()) + processor.reveal(args); + garbler.untaint(); + octetStream buffer; + garbler.P->receive_player(1, buffer); + for (size_t j = 0; j < args.size(); j += 3) + { + int n = args[j]; + int r0 = args[j + 1]; + for (int i = 0; i < DIV_CEIL(n, GC::Clear::N_BITS); i++) + processor.C[r0 + i].unpack(buffer); + } + garbler.taint(); + } + else + processor.reveal(args); } diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 885bde7c2..47ebe8e51 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -23,6 +23,7 @@ class YaoGarbleWire : public YaoWire typedef YaoGarbler Party; typedef YaoGarbleInput Input; typedef GC::Processor> Processor; + typedef SwitchableOutput out_type; static string name() { return "YaoGarbleWire"; } @@ -59,6 +60,7 @@ class YaoGarbleWire : public YaoWire static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>&); + static void reveal_inst(Processor& processor, const vector& args); void randomize(PRNG& prng); void set(Key key, bool mask); diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index 53b0401f4..e6ae6cda1 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -30,6 +30,15 @@ YaoGarbler::YaoGarbler(int thread_num, YaoGarbleMaster& master) : prng.ReSeed(); set_n_program_threads(master.machine.nthreads); this->init(*this); + if (continuous()) + taint(); + else + { + processor.out.activate(false); + if (not master.opts.cmd_private_output_file.empty()) + cerr << "Garbling party cannot output with one-shot computation" + << endl; + } } YaoGarbler::~YaoGarbler() diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index fcebe2875..a947b1f51 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -7,39 +7,13 @@ #include "YaoGarbler.h" #include "YaoEvaluator.h" #include "Tools/ezOptionParser.h" +#include "Tools/NetworkOptions.h" #include "GC/Machine.hpp" YaoPlayer::YaoPlayer(int argc, const char** argv) { ez::ezOptionParser opt; - opt.add( - "", // Default. - 1, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "This player's number, 0 for garbling, 1 for evaluating.", // Help description. - "-p", // Flag token. - "--player" // Flag token. - ); - opt.add( - "localhost", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Host where party 0 is running (default: localhost)", // Help description. - "-h", // Flag token. - "--hostname" // Flag token. - ); - opt.add( - "5000", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Base port number (default: 5000).", // Help description. - "-pn", // Flag token. - "--portnum" // Flag token. - ); opt.add( "", // Default. 0, // Required? @@ -59,33 +33,15 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) "--threshold" // Flag token. ); auto& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv}; - opt.parse(argc, argv); - opt.syntax = "./yao-player.x [OPTIONS] "; - vector free_args = opt.firstArgs; - free_args.insert(free_args.end(), opt.unknownArgs.begin(), opt.unknownArgs.end()); - free_args.insert(free_args.end(), opt.lastArgs.begin(), opt.lastArgs.end()); - if (free_args.size() == 2) - { - progname = *free_args[1]; - } - else - { - string usage; - opt.getUsage(usage); - cerr << usage; - exit(1); - } + online_opts = {opt, argc, argv, false_type()}; + NetworkOptionsWithNumber network_opts(opt, argc, argv, 2, false); + online_opts.finalize(opt, argc, argv); - int my_num; - int pnb; - string hostname; + int my_num = online_opts.playerno; int threshold; - opt.get("-p")->getInt(my_num); - opt.get("-pn")->getInt(pnb); - opt.get("-h")->getString(hostname); bool continuous = not opt.get("-O")->isSet; opt.get("-t")->getInt(threshold); + progname = online_opts.progname; GC::ThreadMasterBase* master; if (my_num == 0) @@ -93,7 +49,7 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) else master = new YaoEvalMaster(continuous, online_opts); - Server::start_networking(master->N, my_num, 2, hostname, pnb); + network_opts.start_networking(master->N, my_num); master->run(progname); if (my_num == 1) diff --git a/compile.py b/compile.py index eea78ec8d..da1b69ee3 100755 --- a/compile.py +++ b/compile.py @@ -78,7 +78,7 @@ def main(): "(number of parties as argument)") parser.add_option("-C", "--CISC", action="store_true", dest="cisc", help="faster CISC compilation mode") - parser.add_option("-K", "--keep-cisc", action="store_true", dest="keep_cisc", + parser.add_option("-K", "--keep-cisc", dest="keep_cisc", help="don't translate CISC instructions") parser.add_option("-l", "--flow-optimization", action="store_true", dest="flow_optimization", help="optimize control flow") diff --git a/doc/Doxyfile b/doc/Doxyfile index 7fa2bafe7..771f8cf13 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/_static/custom.css b/doc/_static/custom.css new file mode 100644 index 000000000..d756b5ce6 --- /dev/null +++ b/doc/_static/custom.css @@ -0,0 +1,3 @@ +.wy-table-responsive table td { +white-space: normal; +} diff --git a/doc/add-protocol.rst b/doc/add-protocol.rst index e9cad665b..ff6a05081 100644 --- a/doc/add-protocol.rst +++ b/doc/add-protocol.rst @@ -40,11 +40,13 @@ found in ``Protocols/Replicated.h``. 1. Fill in the :c:func:`constant` static member function of :c:type:`NoShare` as well as the :c:func:`exchange` member function - of c:type:`NoOutput`. Check out + of :c:type:`NoOutput`. Check out :c:func:`DirectSemiMC::exchange_` in ``Protocols/SemiMC.hpp`` for a simple example. It opens an additive secret sharing by sending all shares to all other parties and then summing up the - received. Constant sharing and public output allows to execute the + received. See :ref:`this reference ` for + documentation on the necessary infrastructure. + Constant sharing and public output allows to execute the following program:: print_ln('%s', sint(123).reveal()) diff --git a/doc/conf.py b/doc/conf.py index bf13bf574..a57f08fa0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -186,3 +186,6 @@ breathe_default_project = 'mp-spdz' import subprocess subprocess.call('doxygen', shell=True) + +def setup(app): + app.add_stylesheet('custom.css') diff --git a/doc/gen-instructions.py b/doc/gen-instructions.py index db0c5a2d8..561f25c63 100755 --- a/doc/gen-instructions.py +++ b/doc/gen-instructions.py @@ -24,7 +24,7 @@ for name, opcode in sorted(items, key=lambda x: x[1]): d, n = desc.get(opcode, (None, None)) if d and '$' not in d and '|' not in d and opcode not in \ - (0x64, 0x65, 0x66, 0x6a) : + (0x65, 0x6a) : m = re.split(r'\.\s', d) if m: d = m[0] diff --git a/doc/instructions.rst b/doc/instructions.rst index 795077707..1a833994e 100644 --- a/doc/instructions.rst +++ b/doc/instructions.rst @@ -89,7 +89,6 @@ Compiler.instructions module print_char_regint, protectmemc, sqrs, start_grind, startprivateoutput, stop_grind, stopprivateoutput, writesocketc, writesocketint, - writesockets, readsockets, protectmemint, protectmems, print_mem, matmul_base, g2muls, inputmixed_base, raw_output diff --git a/doc/low-level.rst b/doc/low-level.rst index 5325fd623..0aaf3708e 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -254,3 +254,151 @@ necessary to call the checking in order to verify the outputs. This frees the memory used for global key material when using homomorphic encryption. Otherwise, this does not do anything. + + +Domain Types +------------ + +.. list-table:: + :widths: 20 80 + + * + - ``gfp_`` + - Computation modulo a prime. ``L`` is the number of 64-bit + limbs, that is, it covers primes of bit length + :math:`64(L-1)+1` to :math:`64L`. The type has to be + initialized using :cpp:func:`init_field` or + :cpp:func:`init_default`. The latter picks a prime given a bit length. + * + - ``SignedZ2`` / ``Z2`` + - Computation modulo :math:`2^K`. This is not a field. + * + - ``gf2n_short`` / ``gf2n_long`` / ``gf2n_`` + - :math:`GF(2^n)`. ``T`` denotes a type that is used to store the + values. It must support a variety of integer + operations. The type has to be initialized using + :cpp:func:`init_field`. The choice of degrees is limited. At + the time of writing, 4, 8, 28, 40, 63, and 128 are supported if the + storage type is large enough. + +Share Types +------------ + +.. list-table:: + :widths: 20 80 + :header-rows: 1 + + * + - Type + - Protocol + * + - ``AtlasShare`` + - Semi-honest version of `ATLAS + `_ (Section 4.2). ``T`` must + represent a field. + * + - ``ChaiGearShare`` + - `HighGear `_ with covert key + setup. ``T`` must be ``gfp_`` or ``gf2n_short``. + * + - ``CowGearShare`` + - `LowGear `_ with covert key + setup. ``T`` must be ``gfp_`` or ``gf2n_short``. + * + - ``HemiShare`` + - Semi-honest protocol with Beaver multiplication based on + semi-homomorphic encryption. ``T`` must be ``gfp_`` or + ``gf2n_short``. + * + - ``HighGearShare`` + - `HighGear `_. ``T`` must be + ``gfp_`` or ``gf2n_short``. + * + - ``LowGearShare`` + - `LowGear `_. ``T`` must be + ``gfp_`` or ``gf2n_short``. + * + - ``MaliciousShamirShare`` + - Shamir secret sharing with Beaver multiplication and sacrifice. + ``T`` must represent a field. + * + - ``MamaShare`` + - `MASCOT `_ with multiple + MACs. ``T`` must represent a field, ``N`` is the number of MACs. + * + - ``PostSacriRepFieldShare`` + - `Post-sacrifice `_ protocol + using three-party replicated secret sharing with ``T`` + representing a field. + * + - ``PostSacriRepRingShare`` + - `Post-sacrifice protocol `_ + using replicated three-party secret sharing modulo :math:`2^K` + with security parameter ``S``. + * + - ``Rep3Share2`` + - `Three-party semi-honest protocol + `_ using replicated secret + sharing modulo :math:`2^K`. + * + - ``Rep4Share`` + - `Four-party malicious protocol + `_ using replicated secret + sharing over a field. + * + - ``Rep4Share2`` + - `Four-party malicious protocol + `_ using replicated secret + sharing modulo :math:`2^K`. + * + - ``SemiShare2`` + - Semi-honest dishonest-majority protocol using Beaver + multiplication based on oblivious transfer modulo :math:`2^K`. + * + - ``SemiShare`` + - Semi-honest dishonest-majority protocol using Beaver + multiplication based on oblivious transfer in a field. + * + - ``ShamirShare`` + - `Semi-honest protocol `_ + based on Shamir's secret sharing. ``T`` must represent a field. + * + - ``Share`` + - `MASCOT `_. ``T`` must + represent a field. + * + - ``SohoShare`` + - Semi-honest protocol with Beaver multiplication based on + somewhat homomorphic encryption. ``T`` must be ``gfp_`` + or ``gf2n_short``. + * + - ``Spdz2kShare`` + - `SPDZ2k `_ computing modulo + :math:`2^K` with security parameter ``S``. + * + - ``SpdzWiseShare`` + - `SPDZ-wise `_ computing + modulo :math:`2^K` with security parameter ``S``. + * + - ``SpdzWiseShare`` + - `SPDZ-wise `_. ``T`` must be + ``MaliciousShamirShare`` or ``MaliciousRep3Share``. + + +Protocol Interfaces +------------------- + +.. doxygenclass:: ProtocolBase + :members: + +.. doxygenclass:: InputBase + :members: + +.. doxygenclass:: MAC_Check_Base + :members: + +.. doxygenclass:: Preprocessing + :members: + +.. doxygenclass:: BufferPrep + :members: diff --git a/doc/networking.rst b/doc/networking.rst index 896a4bf5b..16908681a 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -43,6 +43,8 @@ the same on all hosts, and you have to run ``c_rehash Player-Data`` on all of them. +.. _network-reference: + Internal Infrastructure ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/non-linear.rst b/doc/non-linear.rst index e85f2e24d..5fe8df1f6 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -65,8 +65,46 @@ computation for parts of the computation. MP-SPDZ implements protocols proposed for particular security models by a number of works: `Demmler et al. `_, `Mohassel and Rindal `_, and `Dalskov et -al. `_. MP-SPDZ also implements +al. `_ MP-SPDZ also implements more general methods such as `daBits `_ and `edaBits -`_ +`_. + +Protocol Pairs +============== + +The following table lists the matching arithmetic and binary protocols. + +.. list-table:: + :header-rows: 1 + + * + - Arithmetic + - Binary + * + - MASCOT, SPDZ2k, LowGear, HighGear, CowGear, ChaiGear + - `Tinier `_ with improved + cut-and-choose analysis by `Furukawa et + al. `_ + * + - Semi, Hemi, Soho, Semi2k + - SemiBin (Beaver triples modulo 2 using OT) + * + - `Malicious Shamir `_ + - Malicious Shamir over :math:`GF(2^{40})` for secure sacrificing + * + - Malicious Rep3, Post-Sacrifice, SPDZ-wise replicated + - `Malicious Rep3 modulo 2 `_ + * + - `Rep4 `_ + - Rep4 modulo 2 + * + - `Shamir `_ + - Shamir over :math:`GF(2^8)` + * + - `ATLAS `_ + - ATLAS over :math:`GF(2^8)` + * + - `Rep3 `_ + - Rep3 diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 09b877542..1c096d985 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -53,17 +53,30 @@ that computation can seem oddly slow or fast. For example, one multiplication has a similar cost then some thousand multiplications when using homomorphic encryption because one batch contains information for more than than 10,000 multiplications. Only when a -second batch is necessary the cost shoots up. +second batch is necessary the cost shoots up. Other preprocessing +methods allow for a variable batch size, which can be changed using +``-b``. Smaller batch sizes generally reduce the communication cost +while potentially increasing the number of communication rounds. Try +adding ``-b 10`` to the virtal machine (or script) arguments for very +short computations. Handshake failures ~~~~~~~~~~~~~~~~~~ If you run on different hosts, the certificates -(``Player-Data/*.pem``) must be the same on all of them. Also make -sure to run ``c_rehash Player-Data`` on all hosts. Finally, the -certificates generated by ``Scripts/setup-ssl.sh`` expire after a -month, so you might to regenerate them. +(``Player-Data/*.pem``) must be the same on all of them. Furthermore, +party ```` requires ``Player-Data/P.key`` that must match +``Player-Data/P.pem``, that is, they have to be generated to +together. The easiest way of setting this up is to run +``Scripts/setup-ssl.sh`` on one host and then copy all +``Player-Data/*.{pem,key}`` to all other hosts. This is *not* secure +but it suffices for experiments. A secure setup would generate every +key pair locally and then distributed only the public keys. Finally, +run ``c_rehash Player-Data`` on all hosts. The certificates generated +by ``Scripts/setup-ssl.sh`` expire after a month, so you need to +regenerate them. The same holds for ``Scripts/setup-client.sh`` if you +use the client facility. Connection failures From 0603e43375f9a893b8270373b3e8875583ca94e9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 4 Nov 2021 17:37:02 +1100 Subject: [PATCH 003/221] Fix deprecated Sphinx interface. --- doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index a57f08fa0..57f730add 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -188,4 +188,4 @@ subprocess.call('doxygen', shell=True) def setup(app): - app.add_stylesheet('custom.css') + app.add_css_file('custom.css') From ab637517881270b9fdf5e82bfb49ff69b7170f7f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 20 Nov 2021 18:01:26 +1100 Subject: [PATCH 004/221] Instruction output functionality. --- Processor/Instruction.cpp | 32 ++++++++++++ Processor/Instruction.hpp | 20 ++------ Processor/instructions.h | 104 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 16 deletions(-) diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index db03d6d06..68acda3b9 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -7,6 +7,7 @@ #include "instructions.h" #include "Processor.h" #include "Math/gf2n.h" +#include "GC/instructions.h" #include @@ -89,6 +90,37 @@ void Instruction::bitdecint(ArithmeticProcessor& Proc) const } } +ostream& operator<<(ostream& s, const Instruction& instr) +{ + switch (instr.get_opcode()) + { +#define X(NAME, PRE, CODE) \ + case NAME: s << #NAME; break; + ALL_INSTRUCTIONS +#undef X +#define X(NAME, CODE) \ + case NAME: s << #NAME; break; + COMBI_INSTRUCTIONS + } + + s << " size=" << instr.get_size(); + s << " n=" << instr.get_n(); + s << " r=("; + for (int i = 0; i < 3; i++) + s << instr.get_r(i) << ", "; + s << instr.get_r(3); + s << ")"; + if (not instr.get_start().empty()) + { + s << " args=("; + for (unsigned i = 0; i < instr.get_start().size() - 1; i++) + s << instr.get_start()[i] << ", "; + s << instr.get_start().back(); + s << ")"; + } + return s; +} + template void Instruction::execute_clear_gf2n(vector& registers, vector& memory, ArithmeticProcessor& Proc) const; template void Instruction::execute_clear_gf2n(vector& registers, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 72184d9e3..27bec2b37 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -805,22 +805,6 @@ bool BaseInstruction::is_direct_memory_access() const } -inline -ostream& operator<<(ostream& s,const Instruction& instr) -{ - s << instr.opcode << " : "; - for (int i=0; i<3; i++) - { s << instr.r[i] << " "; } - s << " : " << instr.n; - if (instr.start.size()!=0) - { s << " : " << instr.start.size() << " : "; - for (unsigned int i=0; i inline void Instruction::execute(Processor& Proc) const { @@ -1287,6 +1271,10 @@ void Program::execute(Processor& Proc) const Proc.stats[p[Proc.PC].get_opcode()]++; #endif +#ifdef OUTPUT_INSTRUCTIONS + cerr << instruction << endl; +#endif + Proc.PC++; switch(instruction.get_opcode()) diff --git a/Processor/instructions.h b/Processor/instructions.h index 49901822b..5928fdabc 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -280,4 +280,108 @@ X(GRAWOUTPUT, auto source = &C2[r[0]], \ (*source++).output(Proc.public_output, false)) \ +#define REMAINING_INSTRUCTIONS \ + X(CONVMODP, throw not_implemented(),) \ + X(LDMC, throw not_implemented(),) \ + X(LDMCI, throw not_implemented(),) \ + X(STMC, throw not_implemented(),) \ + X(STMCI, throw not_implemented(),) \ + X(MOVC, throw not_implemented(),) \ + X(DIVC, throw not_implemented(),) \ + X(GDIVC, throw not_implemented(),) \ + X(FLOORDIVC, throw not_implemented(),) \ + X(MODC, throw not_implemented(),) \ + X(LEGENDREC, throw not_implemented(),) \ + X(DIGESTC, throw not_implemented(),) \ + X(DIVCI, throw not_implemented(),) \ + X(GDIVCI, throw not_implemented(),) \ + X(INV2M, throw not_implemented(),) \ + X(MODCI, throw not_implemented(),) \ + X(SQUARE, throw not_implemented(),) \ + X(GSQUARE, throw not_implemented(),) \ + X(INV, throw not_implemented(),) \ + X(GINV, throw not_implemented(),) \ + X(RANDOMS, throw not_implemented(),) \ + X(INPUTMASKREG, throw not_implemented(),) \ + X(INPUTMASK, throw not_implemented(),) \ + X(GINPUTMASK, throw not_implemented(),) \ + X(INPUT, throw not_implemented(),) \ + X(GINPUT, throw not_implemented(),) \ + X(INPUTFIX, throw not_implemented(),) \ + X(INPUTFLOAT, throw not_implemented(),) \ + X(INPUTMIXED, throw not_implemented(),) \ + X(INPUTMIXEDREG, throw not_implemented(),) \ + X(RAWINPUT, throw not_implemented(),) \ + X(GRAWINPUT, throw not_implemented(),) \ + X(INPUTPERSONAL, throw not_implemented(),) \ + X(NOTC, throw not_implemented(),) \ + X(SHRSI, throw not_implemented(),) \ + X(OPEN, throw not_implemented(),) \ + X(GOPEN, throw not_implemented(),) \ + X(MULS, throw not_implemented(),) \ + X(GMULS, throw not_implemented(),) \ + X(MULRS, throw not_implemented(),) \ + X(GMULRS, throw not_implemented(),) \ + X(DOTPRODS, throw not_implemented(),) \ + X(GDOTPRODS, throw not_implemented(),) \ + X(MATMULS, throw not_implemented(),) \ + X(MATMULSM, throw not_implemented(),) \ + X(CONV2DS, throw not_implemented(),) \ + X(TRUNC_PR, throw not_implemented(),) \ + X(CHECK, throw not_implemented(),) \ + X(JMP, throw not_implemented(),) \ + X(JMPI, throw not_implemented(),) \ + X(JMPNZ, throw not_implemented(),) \ + X(JMPEQZ, throw not_implemented(),) \ + X(PRINTREG, throw not_implemented(),) \ + X(PRINTREGPLAIN, throw not_implemented(),) \ + X(CONDPRINTPLAIN, throw not_implemented(),) \ + X(PRINTFLOATPLAIN, throw not_implemented(),) \ + X(CONDPRINTSTR, throw not_implemented(),) \ + X(REQBL, throw not_implemented(),) \ + X(GREQBL, throw not_implemented(),) \ + X(USE, throw not_implemented(),) \ + X(USE_INP, throw not_implemented(),) \ + X(USE_EDABIT, throw not_implemented(),) \ + X(USE_MATMUL, throw not_implemented(),) \ + X(USE_PREP, throw not_implemented(),) \ + X(GUSE_PREP, throw not_implemented(),) \ + X(TIME, throw not_implemented(),) \ + X(START, throw not_implemented(),) \ + X(STOP, throw not_implemented(),) \ + X(RUN_TAPE, throw not_implemented(),) \ + X(JOIN_TAPE, throw not_implemented(),) \ + X(CRASH, throw not_implemented(),) \ + X(STARTGRIND, throw not_implemented(),) \ + X(STOPGRIND, throw not_implemented(),) \ + X(NPLAYERS, throw not_implemented(),) \ + X(THRESHOLD, throw not_implemented(),) \ + X(PLAYERID, throw not_implemented(),) \ + X(LISTEN, throw not_implemented(),) \ + X(ACCEPTCLIENTCONNECTION, throw not_implemented(),) \ + X(CLOSECLIENTCONNECTION, throw not_implemented(),) \ + X(READSOCKETINT, throw not_implemented(),) \ + X(READSOCKETC, throw not_implemented(),) \ + X(READSOCKETS, throw not_implemented(),) \ + X(WRITESOCKETINT, throw not_implemented(),) \ + X(WRITESOCKETC, throw not_implemented(),) \ + X(WRITESOCKETS, throw not_implemented(),) \ + X(WRITESOCKETSHARE, throw not_implemented(),) \ + X(WRITEFILESHARE, throw not_implemented(),) \ + X(READFILESHARE, throw not_implemented(),) \ + X(PUBINPUT, throw not_implemented(),) \ + X(RAWOUTPUT, throw not_implemented(),) \ + X(INTOUTPUT, throw not_implemented(),) \ + X(FLOATOUTPUT, throw not_implemented(),) \ + X(STARTPRIVATEOUTPUT, throw not_implemented(),) \ + X(GSTARTPRIVATEOUTPUT, throw not_implemented(),) \ + X(STOPPRIVATEOUTPUT, throw not_implemented(),) \ + X(GSTOPPRIVATEOUTPUT, throw not_implemented(),) \ + X(PREP, throw not_implemented(),) \ + X(GPREP, throw not_implemented(),) \ + X(CISC, throw not_implemented(),) \ + +#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ + CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS + #endif /* PROCESSOR_INSTRUCTIONS_H_ */ From eac6456ec85213c97d0d6004478b3c70d0edf489 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 22 Nov 2021 17:50:58 +1100 Subject: [PATCH 005/221] Allow preprocessing information to be supplied via named pipes. --- Processor/Data_Files.hpp | 5 ++- Processor/OnlineOptions.cpp | 16 +++++++ Processor/OnlineOptions.h | 1 + Processor/PrepBase.cpp | 10 ++++- Programs/Source/test_thread_mul.mpc | 11 +++++ Scripts/test_streaming.sh | 17 ++++++++ Tools/Buffer.cpp | 20 ++++++++- Tools/Buffer.h | 1 + Utils/stream-fake-mascot-triples.cpp | 65 ++++++++++++++++++++++++++++ 9 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 Programs/Source/test_thread_mul.mpc create mode 100755 Scripts/test_streaming.sh create mode 100644 Utils/stream-fake-mascot-triples.cpp diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 6951ed2cc..3635dc0ac 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -26,7 +26,7 @@ Preprocessing* Preprocessing::get_new( return get_live_prep(proc, usage); else return new Sub_Data_Files(machine.get_N(), - machine.template prep_dir_prefix(), usage); + machine.template prep_dir_prefix(), usage, BaseMachine::thread_num); } template @@ -185,6 +185,9 @@ Sub_Data_Files::~Sub_Data_Files() template void Sub_Data_Files::seekg(DataPositions& pos) { + if (OnlineOptions::singleton.file_prep_per_thread) + return; + if (T::LivePrep::use_part) { get_part().seekg(pos); diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 03fa23793..41308603b 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -28,6 +28,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) bucket_size = 4; cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; + file_prep_per_thread = false; #ifdef VERBOSE verbose = true; #else @@ -170,6 +171,16 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--live-preprocessing" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Preprocessing from files by thread (use with pipes)", // Help description. + "-f", // Flag token. + "--file-prep-per-thread" // Flag token. + ); + opt.add( to_string(default_batch_size).c_str(), // Default. 0, // Required? @@ -224,6 +235,11 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, live_prep = not opt.get("-F")->isSet; else live_prep = opt.get("-L")->isSet; + if (opt.isSet("-f")) + { + live_prep = false; + file_prep_per_thread = true; + } opt.get("-b")->getInt(batch_size); opt.get("--memory")->getString(memtype); bits_from_squares = opt.isSet("-Q"); diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 32c80fc2b..de8f1e722 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -29,6 +29,7 @@ class OnlineOptions std::string cmd_private_input_file; std::string cmd_private_output_file; bool verbose; + bool file_prep_per_thread; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index b99d091e6..5c44b9087 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -6,11 +6,17 @@ #include "PrepBase.h" #include "Data_Files.h" +#include "OnlineOptions.h" string PrepBase::get_suffix(int thread_num) { - (void) thread_num; - return ""; + if (OnlineOptions::singleton.file_prep_per_thread) + { + assert(thread_num >= 0); + return "-T" + to_string(thread_num); + } + else + return ""; } string PrepBase::get_filename(const string& prep_data_dir, diff --git a/Programs/Source/test_thread_mul.mpc b/Programs/Source/test_thread_mul.mpc new file mode 100644 index 000000000..42097fd1a --- /dev/null +++ b/Programs/Source/test_thread_mul.mpc @@ -0,0 +1,11 @@ +n = 1000000 +x = sint.Array(n) +x.assign_vector(regint.inc(n)) + +@multithread(2, n) +def _(base, size): + x.assign_vector(x.get_vector(base, size) ** 2, base) + +print_ln('%s', x[2].reveal()) +crash(x[2].reveal() != 4) +crash(x[n - 1].reveal() != (n - 1) ** 2) diff --git a/Scripts/test_streaming.sh b/Scripts/test_streaming.sh new file mode 100755 index 000000000..0ff2fb336 --- /dev/null +++ b/Scripts/test_streaming.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +make stream-fake-mascot-triples.x +./compile.py test_thread_mul || exit 1 + +rm Player-Data/2-p-128/Triples-p-P?-T? +mkdir Player-Data/2-p-128 + +for i in 0 1; do + for j in 0 1 2; do + mknod Player-Data/2-p-128/Triples-p-P$i-T$j p || exit 1 + done +done + +./stream-fake-mascot-triples.x & + +Scripts/mascot.sh test_thread_mul -f || exit 1 diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 75cb8b6ed..c669081f8 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -7,6 +7,8 @@ #include "Processor/BaseMachine.h" #include +#include +#include bool BufferBase::rewind = false; @@ -21,8 +23,19 @@ void BufferBase::setup(ifstream* f, int length, const string& filename, this->filename = filename; } +bool BufferBase::is_pipe() +{ + struct stat buf; + if (stat(filename.c_str(), &buf)) + return S_ISFIFO(buf.st_mode); + else + return false; +} + void BufferBase::seekg(int pos) { + assert(not is_pipe()); + #ifdef DEBUG_BUFFER if (pos != 0) printf("seek %d %s thread %d\n", pos, filename.c_str(), @@ -52,6 +65,8 @@ void BufferBase::seekg(int pos) void BufferBase::try_rewind() { + assert(not is_pipe()); + #ifndef INSECURE string type; if (field_type.size() and data_type.size()) @@ -70,6 +85,9 @@ void BufferBase::try_rewind() void BufferBase::prune() { + if (is_pipe()) + return; + if (file and (not file->good() or file->peek() == EOF)) purge(); else if (file and file->tellg() != header_length) @@ -99,7 +117,7 @@ void BufferBase::prune() void BufferBase::purge() { - if (file) + if (file and not is_pipe()) { #ifdef VERBOSE cerr << "Removing " << filename << endl; diff --git a/Tools/Buffer.h b/Tools/Buffer.h index a95dee0d9..941ec4256 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -44,6 +44,7 @@ class BufferBase const char* type = "", const string& field = {}); void seekg(int pos); bool is_up() { return file != 0; } + bool is_pipe(); void try_rewind(); void prune(); void purge(); diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp new file mode 100644 index 000000000..5aa85a054 --- /dev/null +++ b/Utils/stream-fake-mascot-triples.cpp @@ -0,0 +1,65 @@ +/* + * stream-fake-mascot-triples.cpp + * + */ + +#include "Protocols/Share.h" +#include "Math/gfpvar.h" +#include "Tools/benchmarking.h" + +#include "Math/Setup.hpp" +#include "Protocols/fake-stuff.hpp" + +class Info +{ +public: + int thread_num; + int nplayers; + gfpvar key; + pthread_t thread; +}; + +void* run(void* arg) +{ + auto& info = *(Info*) arg; + Files> files(info.nplayers, info.key, PREP_DIR, DATA_TRIPLE, info.thread_num); + SeededPRNG G; + int count = 0; + while (true) + { + gfpvar triple[3]; + for (int i = 0; i < 2; i++) + triple[i].randomize(G); + triple[2] = triple[0] * triple[1]; + for (int i = 0; i < 3; i++) + files.output_shares(triple[i]); + count++; + } + cerr << "failed after " << count << endl; + return 0; +} + +int main() +{ + insecure("preprocessing"); + typedef Share T; + int nplayers = 2; + int lgp = 128; + string prep_data_prefix = PREP_DIR; + gfpvar::generate_setup(prep_data_prefix, nplayers, lgp); + T::mac_key_type keyp; + generate_mac_keys(keyp, nplayers, prep_data_prefix); + + int nthreads = 3; + OnlineOptions::singleton.file_prep_per_thread = true; + vector infos(3); + for (int i = 0; i < nthreads; i++) + { + auto& info = infos[i]; + info.thread_num = i; + info.nplayers = nplayers; + info.key = keyp; + pthread_create(&info.thread, 0, run, &info); + } + pthread_join(infos[0].thread, 0); +} From 10f43e281e7a13f153dece6b27b015da38548b9f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 23 Nov 2021 21:19:52 +1100 Subject: [PATCH 006/221] Fix cleartext comparisons with larger primes. --- Compiler/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 151c805ea..ec062b5cd 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1011,8 +1011,9 @@ def less_than(self, other, bit_length): return regint(self) < regint(other) else: diff = self - other + diff += (1 << (bit_length - 1)) shifted = diff >> (bit_length - 1) - res = regint(shifted & 1) + res = 1 - regint(shifted & 1) return res def __lt__(self, other): From 40431cd52a2dfb179ed958ef058e27bfbedc64f0 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 25 Nov 2021 11:15:54 +1100 Subject: [PATCH 007/221] Fix bug in early abort. --- Compiler/ml.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index a8be2d533..5ff1a3753 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -2011,7 +2011,8 @@ def _(j): i.iadd(1) res = True if self.tol > 0: - res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() + res *= (1 - (loss_sum >= 0) * \ + (loss_sum < self.tol * n_per_epoch)).reveal() return res def reveal_correctness(self, data, truth, batch_size): From 00a2bcba45ea77fe490b30f5b23cb1b6cb58f875 Mon Sep 17 00:00:00 2001 From: rtaiello <41542771+rtaiello@users.noreply.github.com> Date: Mon, 29 Nov 2021 14:11:38 +0100 Subject: [PATCH 008/221] Fix problem with Scripts/run-common.sh execution --- ExternalIO/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ExternalIO/README.md b/ExternalIO/README.md index 02b4e1e8e..36649f5c7 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -14,7 +14,7 @@ make bankers-bonus-client.x ./compile.py bankers_bonus 1 Scripts/setup-ssl.sh Scripts/setup-clients.sh 3 -Scripts/.sh & +Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 0 100 0 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 From ae99bed1927dbefd36258971001339587fc3410f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 2 Dec 2021 18:12:11 +1100 Subject: [PATCH 009/221] Access clear bit memory by run-time indices. --- Compiler/GC/instructions.py | 22 ++++++++++++++++++++++ Compiler/GC/types.py | 4 ++-- GC/Instruction.h | 2 ++ GC/instructions.h | 2 ++ Processor/Instruction.hpp | 2 ++ 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index f1b5ad236..fc64ae2d2 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -64,6 +64,8 @@ class ClearBitsAF(base.RegisterArgFormat): MULCBI = 0x21c, SHRCBI = 0x21d, SHLCBI = 0x21e, + LDMCBI = 0x258, + STMCBI = 0x259, CONVCINTVEC = 0x21f, PRINTREGSIGNED = 0x220, PRINTREGB = 0x221, @@ -360,6 +362,26 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): code = opcodes['STMSBI'] arg_format = ['sb','ci'] +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): + """ Copy clear bit memory cell with run-time address to clear bit + register. + + :param: destination (cbit) + :param: memory address (regint) + """ + code = opcodes['LDMCBI'] + arg_format = ['cbw','ci'] + +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): + """ Copy clear bit register to clear bit memory cell with run-time + address. + + :param: source (cbit) + :param: memory address (regint) + """ + code = opcodes['STMCBI'] + arg_format = ['cb','ci'] + class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] arg_format = tools.cycle(['sbw','cb','int']) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 6a5e39f13..5a65e73ac 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -228,8 +228,8 @@ class cbits(bits): max_length = 64 reg_type = 'cb' is_clear = True - load_inst = (None, inst.ldmcb) - store_inst = (None, inst.stmcb) + load_inst = (inst.ldmcbi, inst.ldmcb) + store_inst = (inst.stmcbi, inst.stmcb) bitdec = inst.bitdecc conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) conv_cint_vec = inst.convcintvec diff --git a/GC/Instruction.h b/GC/Instruction.h index f8872f12f..e990f954e 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -81,6 +81,8 @@ enum SHRCBI = 0x21d, SHLCBI = 0x21e, CONVCINTVEC = 0x21f, + LDMCBI = 0x258, + STMCBI = 0x259, // don't write PRINTREGSIGNED = 0x220, PRINTREGB = 0x221, diff --git a/GC/instructions.h b/GC/instructions.h index f94da799a..fb44e4e08 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -61,6 +61,8 @@ X(STMCB, PROC.mem_op(SIZE, MMC, PROC.C, IMM, R0)) \ X(LDMSBI, PROC.mem_op(SIZE, PROC.S, MMS, R0, Ci[REG1])) \ X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \ + X(LDMCBI, PROC.mem_op(SIZE, PROC.C, MMC, R0, Ci[REG1])) \ + X(STMCBI, PROC.mem_op(SIZE, MMC, PROC.C, Ci[REG1], R0)) \ X(MOVSB, S0 = PS1) \ X(TRANS, T::trans(PROC, IMM, EXTRA)) \ X(BITB, PROC.random_bit(S0)) \ diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 27bec2b37..e516fdf37 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -101,6 +101,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case STMSI: case LDMSBI: case STMSBI: + case LDMCBI: + case STMCBI: case MOVC: case MOVS: case MOVSB: From b771417e04aa60582a357de14331eaa59a23c101 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 4 Dec 2021 20:38:56 +1100 Subject: [PATCH 010/221] Clear to secret bit conversion with Yao's garbled circuits. --- BMR/Register.h | 5 +++++ GC/FakeSecret.h | 4 ++++ GC/Processor.h | 3 +++ GC/Processor.hpp | 12 ++++++++++++ GC/Secret.h | 5 +++++ GC/ShareSecret.h | 5 +++++ GC/instructions.h | 3 ++- Processor/Processor.hpp | 11 ----------- Yao/YaoEvalWire.cpp | 13 +++++++++++++ Yao/YaoEvalWire.h | 6 ++++++ Yao/YaoGarbleWire.cpp | 15 +++++++++++++++ Yao/YaoGarbleWire.h | 6 ++++++ 12 files changed, 76 insertions(+), 12 deletions(-) diff --git a/BMR/Register.h b/BMR/Register.h index 4d0c1b074..d0a75e930 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -23,6 +23,7 @@ using namespace std; #include "Tools/PointerVector.h" #include "Tools/Bundle.h" #include "Tools/SwitchableOutput.h" +#include "Processor/Instruction.h" //#define PAD_TO_8(n) (n+8-n%8) #define PAD_TO_8(n) (n) @@ -289,6 +290,10 @@ class ProgramRegister : public Phase, public Register static void inputbvec(T& processor, ProcessorBase& input_processor, const vector& args); + template + static void convcbit2s(GC::Processor&, const BaseInstruction&) + { throw runtime_error("convcbit2s not implemented"); } + // most BMR phases don't need actual input template static T get_input(GC::Processor& processor, const InputArgs& args) diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 73013efa5..55c537de3 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -105,6 +105,10 @@ class FakeSecret : public ShareInterface, public BitVec template static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; } + template + static void convcbit2s(GC::Processor&, const BaseInstruction&) + { throw runtime_error("convcbit2s not implemented"); } + static FakeSecret input(GC::Processor& processor, const InputArgs& args); static FakeSecret input(int from, word input, int n_bits); diff --git a/GC/Processor.h b/GC/Processor.h index 2dddf8df2..3cc0c509f 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -100,6 +100,9 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching void reveal(const vector& args); + template + void convcbit2s(const BaseInstruction& instruction); + void print_reg(int reg, int n, int size); void print_reg_plain(Clear& value); void print_reg_signed(unsigned n_bits, Integer value); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 016352d3e..663d55fcb 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -331,6 +331,18 @@ void Processor::reveal(const vector& args) } } +template +template +void Processor::convcbit2s(const BaseInstruction& instruction) +{ + int unit = GC::Clear::N_BITS; + auto& share_thread = ShareThread::s(); + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + S[instruction.get_r(0) + i] = T::constant(C[instruction.get_r(1) + i], + share_thread.P->my_num(), share_thread.MC->get_alphai(), + min(unsigned(unit), instruction.get_n() - i * unit)); +} + template void Processor::print_reg(int reg, int n, int size) { diff --git a/GC/Secret.h b/GC/Secret.h index 6b37aa21c..14f6638af 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -18,6 +18,7 @@ #include "Math/gf2nlong.h" #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Tools/FixedVector.h" @@ -122,6 +123,10 @@ class Secret Processor& proc) { T::convcbit(dest, source, proc); } + template + static void convcbit2s(Processor& processor, const BaseInstruction& instruction) + { T::convcbit2s(processor, instruction); } + Secret(); Secret(const Integer& x) { *this = x; } diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 9ea0d2f68..10cf65c04 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -21,6 +21,7 @@ using namespace std; #include "Protocols/ReplicatedMC.h" #include "Processor/DummyProtocol.h" #include "Processor/ProcessorBase.h" +#include "Processor/Instruction.h" namespace GC { @@ -74,6 +75,10 @@ class ShareSecret template static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; } + template + static void convcbit2s(Processor& processor, const BaseInstruction& instruction) + { processor.convcbit2s(instruction); } + static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } void check_length(int n, const Integer& x); diff --git a/GC/instructions.h b/GC/instructions.h index fb44e4e08..fc278d441 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -82,7 +82,7 @@ X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ X(CONVCINTVEC, Proc.convcintvec(instruction)) \ X(CONVCBITVEC, Proc.convcbitvec(instruction)) \ - X(CONVCBIT2S, Proc.convcbit2s(instruction)) \ + X(CONVCBIT2S, PROC.convcbit2s(instruction)) \ X(DABIT, Proc.dabit(INST)) \ X(EDABIT, Proc.edabit(INST)) \ X(SEDABIT, Proc.edabit(INST, true)) \ @@ -99,6 +99,7 @@ X(CONVSINT, S0.load_clear(IMM, PI1)) \ X(CONVCINT, C0 = PI1) \ X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \ + X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \ X(PRINTCHR, PROC.print_chr(IMM)) \ X(PRINTSTR, PROC.print_str(IMM)) \ X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \ diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index ebbc1c8cc..6206e27c2 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -219,17 +219,6 @@ void Processor::convcintvec(const Instruction& instruction) } } -template -void Processor::convcbit2s(const Instruction& instruction) -{ - int unit = GC::Clear::N_BITS; - for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) - Procb.S[instruction.get_r(0) + i] = sint::bit_type::constant( - Procb.C[instruction.get_r(1) + i], P.my_num(), - share_thread.MC->get_alphai(), - min(unsigned(unit), instruction.get_n() - i * unit)); -} - template void Processor::split(const Instruction& instruction) { diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index b3240e086..38cdc922e 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -243,6 +243,19 @@ void YaoEvalWire::reveal_inst(Processor& processor, const vector& args) } } +void YaoEvalWire::convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction) +{ + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto& dest = processor.S[instruction.get_r(0) + i]; + dest.resize_regs(min(unsigned(unit), instruction.get_n() - i * unit)); + for (auto& reg : dest.get_regs()) + reg.set(0); + } +} + template void YaoEvalWire::and_( GC::Processor >& processor, const vector& args); diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index dc5d45a91..796d35615 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -10,6 +10,7 @@ #include "BMR/Gate.h" #include "BMR/Register.h" #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "config.h" #include "YaoWire.h" @@ -19,6 +20,8 @@ class ProcessorBase; class YaoEvalWire : public YaoWire { + typedef GC::Secret whole_type; + public: typedef YaoEvaluator Party; typedef YaoEvalInput Input; @@ -61,6 +64,9 @@ class YaoEvalWire : public YaoWire GC::Processor>&); static void reveal_inst(Processor& processor, const vector& args); + static void convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction); + void set(const Key& key); void set(Key key, bool external); diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 7e52602ec..37931df43 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -230,3 +230,18 @@ void YaoGarbleWire::reveal_inst(Processor& processor, const vector& args) else processor.reveal(args); } + +void YaoGarbleWire::convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction) +{ + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto& dest = processor.S[instruction.get_r(0) + i]; + int n = min(unsigned(unit), instruction.get_n() - i * unit); + dest.resize_regs(n); + for (int j = 0; j < n; j++) + dest.get_reg(i).public_input( + processor.C[instruction.get_r(1) + i].get_bit(j)); + } +} diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 47ebe8e51..cc1ba8ce2 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -10,6 +10,7 @@ #include "BMR/Register.h" #include "config.h" #include "YaoWire.h" +#include "Processor/Instruction.h" #include @@ -19,6 +20,8 @@ class ProcessorBase; class YaoGarbleWire : public YaoWire { + typedef GC::Secret whole_type; + public: typedef YaoGarbler Party; typedef YaoGarbleInput Input; @@ -62,6 +65,9 @@ class YaoGarbleWire : public YaoWire GC::Processor>&); static void reveal_inst(Processor& processor, const vector& args); + static void convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction); + void randomize(PRNG& prng); void set(Key key, bool mask); From e76014e2e9f9a0b8d6c3dc629c80ae3410ad1e70 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 7 Dec 2021 16:42:18 +1100 Subject: [PATCH 011/221] More parallelized SSL handshake. --- Networking/CryptoPlayer.cpp | 29 +++++++++++++++++++++++------ Networking/CryptoPlayer.h | 2 ++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index d0b289b36..c2e1403b9 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -68,7 +68,20 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : player.sockets.clear(); } - for (int i = 0; i < (int)sockets.size(); i++) + for (int offset = 1; offset <= num_players() / 2; offset++) + { + int others[] = { get_player(offset), get_player(-offset) }; + if (my_num() % (2 * offset) < offset) + swap(others[0], others[1]); + + if (num_players() % 2 == 0 and offset == num_players() / 2) + connect(others[0], plaintext_sockets); + else + for (int i = 0; i < 2; i++) + connect(others[i], plaintext_sockets); + } + + for (int i = 0; i < num_players(); i++) { if (i == my_num()) { @@ -79,16 +92,20 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : continue; } - sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i], - "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); - other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i], - "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); - senders[i] = new Sender(i < my_num() ? sockets[i] : other_sockets[i]); receivers[i] = new Receiver(i < my_num() ? other_sockets[i] : sockets[i]); } } +void CryptoPlayer::connect(int i, vector* plaintext_sockets) +{ + sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i], + "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); + other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i], + "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); + +} + CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) : CryptoPlayer(Nms, to_string(id_base)) { diff --git a/Networking/CryptoPlayer.h b/Networking/CryptoPlayer.h index 287f5c66f..3ab3ed565 100644 --- a/Networking/CryptoPlayer.h +++ b/Networking/CryptoPlayer.h @@ -28,6 +28,8 @@ class CryptoPlayer : public MultiPlayer vector*> senders; vector*> receivers; + void connect(int other, vector* plaintext_sockets); + public: /** * Start a new set of encrypted connections. From 047a4775bf2cc0f328f648d9749d322a8330a023 Mon Sep 17 00:00:00 2001 From: HaoXuan40404 <444649358@qq.com> Date: Wed, 8 Dec 2021 14:18:32 +0800 Subject: [PATCH 012/221] init rc2 network --- Compiler/ppc.py | 2 +- Makefile | 12 ++++++++++++ Networking/Player.cpp | 33 +++++++++++++++++++++++++++++++-- Networking/Player.h | 7 ++++++- Networking/Server.cpp | 6 ++++++ Networking/ServerSocket.cpp | 2 +- Processor/OnlineMachine.hpp | 1 + Programs/Source/ppc_circuit.mpc | 3 ++- 8 files changed, 60 insertions(+), 6 deletions(-) diff --git a/Compiler/ppc.py b/Compiler/ppc.py index 70c3e851a..4950b143d 100644 --- a/Compiler/ppc.py +++ b/Compiler/ppc.py @@ -15,7 +15,7 @@ print_float_prec(7) # Use to limit the tester workload -MAX_DATA_LENGTH = 500 +MAX_DATA_LENGTH = 500000 MAX_ML_SIZE = 10000 SECOND_LOOP_SIZE = 1000 diff --git a/Makefile b/Makefile index fcfbee413..67114b631 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,12 @@ GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) OT_EXE = ot.x ot-offline.x +CFLAGS += -DVERBOSE_COMM +CFLAGS += -DDEBUG_THREADS +CFLAGS += -DDEBUG_THREAD_QUEUE +CFLAGS += -DDEBUG_NETWORKING +CFLAGS += -DINSECURE +CFLAGS += -DPPC_COMMUNICATION COMMONOBJS = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) @@ -165,6 +171,12 @@ bmr-clean: bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +ppc-receiver-client.x: ExternalIO/ppc-receiver-client.cpp $(COMMON) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +ppc-sender-client.x: ExternalIO/ppc-sender-client.cpp $(COMMON) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + simple-offline.x: $(FHEOFFLINE) pairwise-offline.x: $(FHEOFFLINE) cnc-offline.x: $(FHEOFFLINE) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 7d8ddcb49..04a063075 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -20,7 +20,9 @@ void Names::init(int player,int pnb,int my_port,const char* servername) player_no=player; portnum_base=pnb; setup_names(servername, my_port); +#ifndef PPC_COMMUNICATION setup_server(); +#endif } Names::Names(int player, int nplayers, const string& servername, int pnb, @@ -83,7 +85,11 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante for (unsigned int i = 0; i < names.size(); i++) cerr << " " << names[i] << ":" << ports[i] << endl; #endif +#ifdef PPC_COMMUNICATION + cerr << "Ppc communication model setup name finished!" << endl; +#else setup_server(); +#endif } Names::Names(ez::ezOptionParser& opt, int argc, const char** argv, @@ -125,7 +131,11 @@ void Names::setup_names(const char *servername, int my_port) my_port = default_port(player_no); int socket_num; +#ifdef PPC_COMMUNICATION + int pn = portnum_base; +#else int pn = portnum_base - 1; +#endif set_up_client_socket(socket_num, servername, pn); octetStream("P" + to_string(player_no)).Send(socket_num); #ifdef DEBUG_NETWORKING @@ -216,8 +226,13 @@ MultiPlayer::MultiPlayer(const Names& Nms) : PlainPlayer::PlainPlayer(const Names& Nms, const string& id) : MultiPlayer(Nms) { - if (Nms.num_players() > 1) + if (Nms.num_players() > 1){ +#ifdef PPC_COMMUNICATION + setup_sockets(Nms.names, Nms.ports, id); +#else setup_sockets(Nms.names, Nms.ports, id, *Nms.server); +#endif + } } @@ -260,11 +275,23 @@ PlayerBase::~PlayerBase() // Set up nmachines client and server sockets to send data back and fro // A machine is a server between it and player i if i<=my_number // Can also communicate with myself, but only with send_to and receive_from +#ifdef PPC_COMMUNICATION +void PlainPlayer::setup_sockets(const vector& names, + const vector& ports, const string& id_base) +#else void PlainPlayer::setup_sockets(const vector& names, const vector& ports, const string& id_base, ServerSocket& server) +#endif { sockets.resize(nplayers); // Set up the client side +#ifdef PPC_COMMUNICATION + for (int i=0; i& names, } octetStream(pn).Send(sockets[i]); } +#endif send_to_self_socket = sockets[player_no]; +#ifndef PPC_COMMUNICATION // Setting up the server side for (int i=0; i<=player_no; i++) { auto id=id_base+"P"+to_string(i); @@ -295,7 +324,7 @@ void PlainPlayer::setup_sockets(const vector& names, #endif sockets[i] = server.get_connection_socket(id); } - +#endif for (int i = 0; i < nplayers; i++) { // timeout of 5 minutes struct timeval tv; diff --git a/Networking/Player.h b/Networking/Player.h index 668c097a7..e686397ae 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -415,8 +415,13 @@ class MultiPlayer : public Player */ class PlainPlayer : public MultiPlayer { - void setup_sockets(const vector& names, const vector& ports, + #ifdef PPC_COMMUNICATION +void setup_sockets(const vector& names, const vector& ports, + const string& id_base); +#else +void setup_sockets(const vector& names, const vector& ports, const string& id_base, ServerSocket& server); +#endif public: /** diff --git a/Networking/Server.cpp b/Networking/Server.cpp index d9a056dd2..b621f2f8d 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -166,8 +166,13 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, #endif assert(my_num >= 0); assert(my_num < nplayers); +#ifndef PPC_COMMUNICATION Server* server = 0; pthread_t thread; +#endif +#ifdef PPC_COMMUNICATION + N.init(my_num, portnum, my_port, hostname.c_str()); +#else if (my_num == 0) { pthread_create(&thread, 0, Server::start_in_thread, @@ -179,5 +184,6 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, pthread_join(thread, 0); delete server; } +#endif return 0; } diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index d69fd7b8d..8c0351a9d 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -195,7 +195,7 @@ int ServerSocket::get_connection_socket(const string& id) while (clients.find(id) == clients.end()) { - if (data_signal.wait(60) == ETIMEDOUT) + if (data_signal.wait(60000) == ETIMEDOUT) throw runtime_error("No client after one minute"); } diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index f728288ff..58e91724f 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -197,6 +197,7 @@ void OnlineMachine::start_networking() throw runtime_error("cannot set port number when using IP file"); if (nplayers == 0 and opt.isSet("-N")) opt.get("-N")->getInt(nplayers); + cerr << "Ppc communication model setup name started!" << endl; playerNames.init(playerno, pnbase, ipFileName, nplayers); } else { if (not opt.get("-ext-server")->isSet) diff --git a/Programs/Source/ppc_circuit.mpc b/Programs/Source/ppc_circuit.mpc index cacf9961f..8f52c4af4 100644 --- a/Programs/Source/ppc_circuit.mpc +++ b/Programs/Source/ppc_circuit.mpc @@ -10,7 +10,8 @@ source2_record = read_array(SOURCE2, 1, pint) def ppc_main(source0_record, source1_record, source2_record): - num_xor = source0_record[0].bit_xor(source1_record[0]) + # num_xor = source0_record[0].bit_xor(source1_record[0]) + num_xor = source0_record[0] ^ source1_record[0] num_and = source0_record[0].bit_and(source2_record[0]) num_not = source1_record[0].bit_not() From cdb0c0f898f0c79b70d0b101872baeb80bd70ba2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 15 Dec 2021 12:58:54 +1100 Subject: [PATCH 013/221] In-place operations for containers. --- Compiler/types.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/Compiler/types.py b/Compiler/types.py index ec062b5cd..48bb27a1d 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5420,6 +5420,22 @@ def __pow__(self, value): __radd__ = __add__ __rmul__ = __mul__ + def __iadd__(self, other): + self[:] += other.get_vector() + return self + + def __isub__(self, other): + self[:] -= other.get_vector() + return self + + def __imul__(self, other): + self[:] *= other.get_vector() + return self + + def __itruediv__(self, other): + self[:] /= other.get_vector() + return self + def __neg__(self): return -self.get_vector() @@ -5770,6 +5786,22 @@ def iadd(self, other): assert self.sizes == other.sizes self.assign_vector(self.get_vector() + other.get_vector()) + def __iadd__(self, other): + self[:] += other.get_vector() + return self + + def __isub__(self, other): + self[:] -= other.get_vector() + return self + + def __imul__(self, other): + self[:] *= other.get_vector() + return self + + def __itruediv__(self, other): + self[:] /= other.get_vector() + return self + def __mul__(self, other): # legacy function return self.mul(other) From 95c0383e1a24277ab9ffd9bdc54bcbbac00dcd11 Mon Sep 17 00:00:00 2001 From: shareong <740310627@qq.com> Date: Wed, 5 Jan 2022 21:19:40 +0800 Subject: [PATCH 014/221] support proxy pattern --- ExternalIO/Client.hpp | 4 ++-- Machines/OTMachine.cpp | 4 ++-- Makefile | 8 ++++---- Networking/Player.cpp | 4 +++- Networking/sockets.cpp | 2 +- Processor/Machine.hpp | 12 ++++++------ Processor/OfflineMachine.hpp | 2 +- Processor/Online-Thread.hpp | 2 +- Processor/OnlineOptions.cpp | 12 ++++++++++++ Processor/PpcConstant.cpp | 9 +++++++++ Processor/PpcConstant.h | 2 ++ Processor/Processor.hpp | 24 +++++++++++++++++++++++- Scripts/setup-ssl.sh | 2 +- Yao/YaoEvaluator.cpp | 2 +- Yao/YaoGarbler.cpp | 2 +- 15 files changed, 69 insertions(+), 22 deletions(-) diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index 601d9a486..21552791e 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -20,8 +20,8 @@ Client::Client(const vector& hostnames, int port_base, { set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i); octetStream(to_string(my_client_id)).Send(plain_sockets[i]); - sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], - "P" + to_string(i), "C" + to_string(my_client_id), true); + // sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], + // "P" + to_string(i), "C" + to_string(my_client_id), true); if (i == 0) specification.Receive(sockets[0]); } diff --git a/Machines/OTMachine.cpp b/Machines/OTMachine.cpp index 351871c9c..1549b7a9a 100644 --- a/Machines/OTMachine.cpp +++ b/Machines/OTMachine.cpp @@ -226,7 +226,7 @@ OTMachine::OTMachine(int argc, const char** argv) N.push_back(new Names(my_num, portnum_base + 1000 * N.size(), names)); } - P = new RealTwoPartyPlayer(*N[0], 1 - my_num, "machine"); + P = new RealTwoPartyPlayer(*N[0], 1 - my_num, "M0"); timeval baseOTstart, baseOTend; gettimeofday(&baseOTstart, NULL); @@ -320,7 +320,7 @@ void OTMachine::run() // now setup resources for each thread // round robin with the names players[i] = new RealTwoPartyPlayer(*N[i % N.size()], 1 - my_num, - "thread" + to_string(i)); + "T" + to_string(i)); tinfos[i].thread_num = i+1; tinfos[i].other_player_num = 1 - my_num; tinfos[i].nOTs = nOTs; diff --git a/Makefile b/Makefile index 67114b631..027893b3c 100644 --- a/Makefile +++ b/Makefile @@ -16,10 +16,10 @@ GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) OT_EXE = ot.x ot-offline.x -CFLAGS += -DVERBOSE_COMM -CFLAGS += -DDEBUG_THREADS -CFLAGS += -DDEBUG_THREAD_QUEUE -CFLAGS += -DDEBUG_NETWORKING +# CFLAGS += -DVERBOSE_COMM +# CFLAGS += -DDEBUG_THREADS +# CFLAGS += -DDEBUG_THREAD_QUEUE +# CFLAGS += -DDEBUG_NETWORKING CFLAGS += -DINSECURE CFLAGS += -DPPC_COMMUNICATION diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 04a063075..031685697 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -287,7 +287,9 @@ void PlainPlayer::setup_sockets(const vector& names, // Set up the client side #ifdef PPC_COMMUNICATION for (int i=0; i::Machine(int my_number, Names& playerNames, // make directory for outputs if necessary mkdir_p(PREP_DIR); - string id = "machine"; + string id = "M0"; if (use_encryption) P = new CryptoPlayer(N, id); else @@ -98,11 +98,11 @@ Machine::Machine(int my_number, Names& playerNames, load_schedule(progname_str); // remove persistence if necessary - for (auto& prog : progs) - { - if (prog.writes_persistance) - ofstream(Binary_File_IO::filename(my_number), ios::out); - } + // for (auto& prog : progs) + // { + // if (prog.writes_persistance) + // ofstream(Binary_File_IO::filename(my_number), ios::out); + // } #ifdef VERBOSE progs[0].print_offline_cost(); diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index b9901a0c6..b94e7f188 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -15,7 +15,7 @@ OfflineMachine::OfflineMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, V, int nplayers) : W(argc, argv, opt, online_opts, V(), nplayers), playerNames( - W::playerNames), P(*this->new_player("machine")) + W::playerNames), P(*this->new_player("M0")) { machine.load_schedule(online_opts.progname, false); Program program(playerNames.num_players()); diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 1ef7da055..8ef264a35 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -49,7 +49,7 @@ void thread_info::Sub_Main_Func() fprintf(stderr, "\tI am in thread %d\n",num); #endif Player* player; - string id = "thread" + to_string(num); + string id = "T" + to_string(num); if (machine.use_encryption) { #ifdef VERBOSE_OPTIONS diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 12b64257c..4cb3c0645 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -209,6 +209,14 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-DBG", // Flag token. "--debug" // Flag token. ); + opt.add("", // Default. + 1, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "7 bytes job id", // Help description. + "-ID", // Flag token. + "--job-id" // Flag token. + ); opt.add("", 0, // Required? 0, // Number of args expected. @@ -250,6 +258,10 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, bool debug_flag = opt.isSet("--debug"); set_debug_flag(debug_flag); + std::string job_id; + opt.get("-ID")->getString(job_id); + set_job_id(job_id); + direct = opt.isSet("--direct"); opt.get("--bucket-size")->getInt(bucket_size); diff --git a/Processor/PpcConstant.cpp b/Processor/PpcConstant.cpp index 64a1ed665..54bd6b873 100644 --- a/Processor/PpcConstant.cpp +++ b/Processor/PpcConstant.cpp @@ -5,6 +5,7 @@ std::string PPC_PREFIX; int CONNETION_waiting_millisecond_FLAG = 0; bool DEBUG_FLAG = false; std::string PPC_DEBUG_PREFIEX = "PPC-LOG-"; +std::string JOB_ID = ""; std::string get_prefix() { return PPC_PREFIX; @@ -30,3 +31,11 @@ bool get_debug_flag() { void set_debug_flag(bool debug_flag) { DEBUG_FLAG = debug_flag; } + +const std::string& get_job_id() { + return JOB_ID; +} + +void set_job_id(const std::string &id) { + JOB_ID = id; +} \ No newline at end of file diff --git a/Processor/PpcConstant.h b/Processor/PpcConstant.h index c5d9dbc03..d1e874c0d 100644 --- a/Processor/PpcConstant.h +++ b/Processor/PpcConstant.h @@ -17,3 +17,5 @@ void set_connection_waiting_millisecond_flag(int sleep_time); bool get_debug_flag(); void set_debug_flag(bool debug_flag); +const std::string& get_job_id(); +void set_job_id(const std::string &id); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index dd1e82382..18f5e5618 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -374,7 +374,7 @@ void Processor::read_socket_private(int client_id, template void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { string filename; - filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; + filename = "Data-Shares/Transactions-P" + to_string(P.my_num()) + ".data"; unsigned int size = data_registers.size(); @@ -401,6 +401,8 @@ void Processor::read_shares_from_file(int start_file_posn, int end_ // Append share data in data_registers to end of file. Expects Persistence directory to exist. template void Processor::write_shares_to_file(const vector& data_registers) { + ofstream(Binary_File_IO::filename(P.my_num()), ios::out); + string filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -415,6 +417,26 @@ void Processor::write_shares_to_file(const vector& data_regist binary_file_io.write_to_file(filename, inpbuf); } +// template +// void Processor::write_shares_to_file(const vector& data_registers) { +// string dir = "Persistence"; +// mkdir_p(dir.c_str()); + +// string filename; +// filename = dir + "/Transactions-P" + to_string(P.my_num()) + ".data"; + +// unsigned int size = data_registers.size(); + +// vector< sint > inpbuf (size); + +// for (unsigned int i = 0; i < size; i++) +// { +// inpbuf[i] = get_Sp_ref(data_registers[i]); +// } + +// binary_file_io.write_to_file(filename, inpbuf); +// } + template void SubProcessor::POpen(const vector& reg,const Player& P,int size) { diff --git a/Scripts/setup-ssl.sh b/Scripts/setup-ssl.sh index ffd79bf0d..56fd03142 100755 --- a/Scripts/setup-ssl.sh +++ b/Scripts/setup-ssl.sh @@ -10,7 +10,7 @@ test -e Player-Data || mkdir Player-Data echo Setting up SSL for $n parties for i in `seq 0 $[n-1]`; do - openssl req -newkey rsa -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" + openssl req -newkey rsa -nodes -x509 -days 3650 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" done c_rehash Player-Data diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 0126c8b15..1553f1729 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -20,7 +20,7 @@ YaoEvaluator::YaoEvaluator(int thread_num, YaoEvalMaster& master) : Thread>(thread_num, master), YaoCommon(master), master(master), - player(N, 0, "thread" + to_string(thread_num)), + player(N, 0, "T" + to_string(thread_num)), ot_ext(OTExtensionWithMatrix::setup(player, {}, RECEIVER, true)) { set_n_program_threads(master.machine.nthreads); diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index 53b0401f4..3f7f0b6b7 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -23,7 +23,7 @@ YaoGarbler::YaoGarbler(int thread_num, YaoGarbleMaster& master) : master(master), and_proc_timer(CLOCK_PROCESS_CPUTIME_ID), and_main_thread_timer(CLOCK_THREAD_CPUTIME_ID), - player(master.N, 1, "thread" + to_string(thread_num)), + player(master.N, 1, "T" + to_string(thread_num)), ot_ext(OTExtensionWithMatrix::setup(player, master.get_delta().get<__m128i>(), SENDER, true)) { From 8a19fd3aeca2d0827b748a9e987233e35b628d99 Mon Sep 17 00:00:00 2001 From: Shareong <740310627@qq.com> Date: Thu, 6 Jan 2022 16:03:05 +0800 Subject: [PATCH 015/221] detach task data --- Processor/Binary_File_IO.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index be1fb8fdb..9a948731f 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -1,4 +1,5 @@ #include "Processor/Binary_File_IO.h" +#include "Processor/PpcConstant.h" /* * Provides generalised file read and write methods for arrays of shares. @@ -10,7 +11,8 @@ inline string Binary_File_IO::filename(int my_number) { string dir = "Persistence"; mkdir_p(dir.c_str()); - return dir + "/Transactions-P" + to_string(my_number) + ".data"; + // return dir + "/Transactions-P" + to_string(my_number) + ".data"; + return dir + "/" + get_job_id() + "-" + to_string(my_number) + ".data"; } template From 3b356c3df4961a6ed63175d6db1b8fe7575313b0 Mon Sep 17 00:00:00 2001 From: Shareong <740310627@qq.com> Date: Thu, 6 Jan 2022 16:04:42 +0800 Subject: [PATCH 016/221] detach task data --- Processor/Processor.hpp | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 18f5e5618..5da8bdc1d 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -3,6 +3,7 @@ #include "Processor/Processor.h" #include "Processor/Program.h" +#include "Processor/PpcConstant.h" #include "GC/square64.h" #include "Protocols/ReplicatedInput.hpp" @@ -374,7 +375,8 @@ void Processor::read_socket_private(int client_id, template void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { string filename; - filename = "Data-Shares/Transactions-P" + to_string(P.my_num()) + ".data"; + // filename = "Data-Shares/Transactions-P" + to_string(P.my_num()) + ".data"; + filename = "Data-Shares/" + get_job_id() + "-" + to_string(P.my_num()) + ".data"; unsigned int size = data_registers.size(); @@ -417,26 +419,6 @@ void Processor::write_shares_to_file(const vector& data_regist binary_file_io.write_to_file(filename, inpbuf); } -// template -// void Processor::write_shares_to_file(const vector& data_registers) { -// string dir = "Persistence"; -// mkdir_p(dir.c_str()); - -// string filename; -// filename = dir + "/Transactions-P" + to_string(P.my_num()) + ".data"; - -// unsigned int size = data_registers.size(); - -// vector< sint > inpbuf (size); - -// for (unsigned int i = 0; i < size; i++) -// { -// inpbuf[i] = get_Sp_ref(data_registers[i]); -// } - -// binary_file_io.write_to_file(filename, inpbuf); -// } - template void SubProcessor::POpen(const vector& reg,const Player& P,int size) { From e07d9bf2a3231fe95557106371ce25f3da32f5d6 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 11 Jan 2022 16:04:59 +1100 Subject: [PATCH 017/221] Maintenance. --- .gitmodules | 2 +- BMR/Party.cpp | 2 +- BMR/RealGarbleWire.hpp | 2 +- BMR/RealProgramParty.hpp | 8 +- BMR/Register.h | 3 + BMR/TrustedParty.cpp | 6 + BMR/TrustedParty.h | 3 +- CHANGELOG.md | 13 ++ Compiler/GC/types.py | 10 +- Compiler/comparison.py | 11 +- Compiler/compilerLib.py | 2 +- Compiler/exceptions.py | 5 +- Compiler/floatingpoint.py | 14 +- Compiler/instructions.py | 17 +- Compiler/instructions_base.py | 83 ++++++++- Compiler/library.py | 23 ++- Compiler/ml.py | 12 +- Compiler/non_linear.py | 43 +++-- Compiler/program.py | 28 ++- Compiler/types.py | 50 +++-- ECDSA/hm-ecdsa-party.hpp | 4 +- ECDSA/mascot-ecdsa-party.cpp | 2 + ECDSA/ot-ecdsa-party.hpp | 4 +- ECDSA/preprocessing.hpp | 10 +- ECDSA/sign.hpp | 20 +- ExternalIO/Client.h | 25 +++ ExternalIO/README.md | 59 ++---- FHE/FHE_Params.h | 3 - FHE/NTL-Subs.cpp | 7 +- FHE/NTL-Subs.h | 5 +- FHE/NoiseBounds.cpp | 1 - FHE/Ring_Element.cpp | 22 ++- FHE/Ring_Element.h | 1 + FHEOffline/PairwiseGenerator.cpp | 2 +- FHEOffline/SimpleGenerator.h | 2 +- GC/BitAdder.hpp | 2 +- GC/CcdPrep.h | 5 - GC/CcdPrep.hpp | 8 + GC/CcdShare.h | 1 + GC/FakeSecret.h | 3 + GC/Instruction.cpp | 2 +- GC/NoShare.h | 6 +- GC/PostSacriBin.cpp | 16 +- GC/PostSacriBin.h | 5 +- GC/RepPrep.hpp | 5 + GC/Secret.h | 3 + GC/SemiPrep.cpp | 13 +- GC/SemiPrep.h | 4 +- GC/ShareSecret.h | 2 + GC/ShareSecret.hpp | 10 +- GC/ShareThread.h | 3 - GC/ShareThread.hpp | 3 +- GC/Thread.h | 2 - GC/Thread.hpp | 7 - GC/ThreadMaster.hpp | 4 +- GC/TinierSharePrep.h | 2 - GC/TinierSharePrep.hpp | 18 +- GC/TinyPrep.hpp | 2 +- GC/VectorInput.h | 6 + GC/VectorProtocol.h | 7 +- GC/VectorProtocol.hpp | 18 +- GC/instructions.h | 2 +- License.txt | 2 +- Machines/Atlas.hpp | 16 ++ Machines/Rep.hpp | 1 + Machines/Rep4.hpp | 17 ++ Machines/SPDZ.hpp | 12 +- Machines/SPDZ2k.hpp | 11 +- Machines/Semi.hpp | 1 + Machines/Semi2k.hpp | 15 ++ Machines/ShamirMachine.hpp | 1 + Machines/Tinier.cpp | 23 +++ Machines/atlas-party.cpp | 7 +- Machines/emulate.cpp | 8 +- Machines/hemi-party.cpp | 1 + Machines/no-party.cpp | 1 + Machines/soho-party.cpp | 1 + Makefile | 28 +-- Math/BitVec.h | 20 +- Math/Setup.hpp | 5 +- Math/ValueInterface.h | 1 + Math/Z2k.h | 3 + Math/Zp_Data.cpp | 36 ++++ Math/Zp_Data.h | 33 +--- Math/gfp.h | 2 +- Networking/CryptoPlayer.cpp | 5 + Networking/Player.cpp | 49 +++-- Networking/Player.h | 17 +- Networking/Receiver.cpp | 8 + Networking/Sender.cpp | 12 +- Networking/Server.cpp | 46 +++-- Networking/Server.h | 7 +- Networking/ssl_sockets.h | 13 ++ OT/BaseOT.cpp | 18 ++ OT/NPartyTripleGenerator.h | 13 +- Processor/BaseMachine.cpp | 34 +++- Processor/BaseMachine.h | 15 +- Processor/Binary_File_IO.hpp | 2 +- Processor/Data_Files.h | 37 ++-- Processor/Data_Files.hpp | 18 +- Processor/DummyProtocol.h | 4 +- Processor/FieldMachine.h | 5 +- Processor/FieldMachine.hpp | 3 +- Processor/HonestMajorityMachine.cpp | 2 +- Processor/Input.h | 5 +- Processor/Input.hpp | 10 +- Processor/Instruction.hpp | 2 + Processor/Machine.h | 1 - Processor/Machine.hpp | 16 +- Processor/Memory.h | 3 + Processor/NoFilePrep.h | 22 +++ Processor/OfflineMachine.hpp | 6 +- Processor/Online-Thread.hpp | 15 +- Processor/OnlineOptions.cpp | 14 ++ Processor/OnlineOptions.h | 6 + Processor/OnlineOptions.hpp | 30 +++ Processor/PrepBase.cpp | 20 +- Processor/PrepBase.h | 6 +- Processor/Processor.h | 4 - Processor/Processor.hpp | 50 +---- Processor/RingMachine.h | 2 +- Processor/RingMachine.hpp | 3 +- Processor/ThreadQueue.cpp | 21 +++ Processor/ThreadQueue.h | 5 + Processor/TruncPrTuple.h | 37 +++- Programs/Source/keras_mnist_lenet_predict.mpc | 44 +++++ Protocols/Atlas.h | 11 +- Protocols/Atlas.hpp | 9 +- Protocols/Beaver.h | 11 +- Protocols/Beaver.hpp | 30 ++- Protocols/BrainShare.h | 2 + Protocols/FakeProtocol.h | 35 ++-- Protocols/FakeShare.h | 3 + Protocols/Hemi.h | 6 +- Protocols/Hemi.hpp | 22 ++- Protocols/HighGearKeyGen.cpp | 2 +- Protocols/LowGearKeyGen.cpp | 2 +- Protocols/LowGearKeyGen.hpp | 2 +- Protocols/MAC_Check.hpp | 2 + Protocols/MalRepRingPrep.hpp | 41 ----- Protocols/MaliciousRep3Share.h | 1 + Protocols/MaliciousRepPO.h | 8 +- Protocols/MaliciousRepPO.hpp | 18 +- Protocols/MaliciousRepPrep.hpp | 5 +- Protocols/MamaPrep.hpp | 1 + Protocols/MascotPrep.h | 2 - Protocols/MascotPrep.hpp | 12 +- Protocols/NoProtocol.h | 4 +- Protocols/PostSacriRepRingShare.h | 2 + Protocols/PostSacrifice.h | 4 +- Protocols/PostSacrifice.hpp | 7 +- Protocols/ProtocolSet.h | 107 +++++++++++ Protocols/ProtocolSetup.h | 95 ++++++++++ Protocols/Rep3Share.h | 27 +++ Protocols/Rep3Share2k.h | 12 -- Protocols/Rep4.h | 12 +- Protocols/Rep4.hpp | 34 ++-- Protocols/Rep4Prep.hpp | 2 +- Protocols/Replicated.h | 20 +- Protocols/Replicated.hpp | 174 +++++++----------- Protocols/ReplicatedInput.h | 3 +- Protocols/ReplicatedInput.hpp | 2 +- Protocols/ReplicatedPO.h | 24 +++ Protocols/ReplicatedPO.hpp | 21 +++ Protocols/ReplicatedPrep.h | 25 ++- Protocols/ReplicatedPrep.hpp | 150 +++++++++------ Protocols/{Semi2k.h => Semi.h} | 23 ++- Protocols/Semi2kShare.h | 6 +- Protocols/SemiShare.h | 8 +- Protocols/Shamir.h | 16 +- Protocols/Shamir.hpp | 14 +- Protocols/ShuffleSacrifice.hpp | 16 +- Protocols/Spdz2kPrep.h | 3 - Protocols/Spdz2kPrep.hpp | 35 ++-- Protocols/SpdzWise.h | 11 +- Protocols/SpdzWise.hpp | 28 +-- Protocols/SpdzWiseInput.hpp | 3 +- Protocols/SpdzWisePrep.hpp | 13 +- Protocols/SpdzWiseRing.hpp | 2 +- Protocols/SquarePrep.h | 6 +- README.md | 6 +- Scripts/decompile.py | 16 ++ Scripts/memory-usage.py | 29 +++ Scripts/run-common.sh | 31 +--- Scripts/test_streaming.sh | 4 + Scripts/tldr.sh | 3 +- Tools/BitVector.cpp | 9 + Tools/BitVector.h | 9 +- Tools/Buffer.cpp | 13 +- Tools/Bundle.h | 2 +- Tools/TimerWithComm.cpp | 23 +++ Tools/TimerWithComm.h | 23 +++ Tools/benchmarking.cpp | 15 ++ Tools/benchmarking.h | 3 + Tools/octetStream.h | 4 +- Tools/random.cpp | 4 +- Tools/random.h | 2 + Utils/Fake-Offline.cpp | 2 +- Utils/binary-example.cpp | 140 ++++++++++++++ Utils/mixed-example.cpp | 137 ++++++++++++++ Utils/paper-example.cpp | 49 ++--- Utils/stream-fake-mascot-triples.cpp | 21 ++- Yao/YaoEvaluator.h | 3 - Yao/YaoGarbler.cpp | 5 - Yao/YaoGarbler.h | 2 - Yao/YaoWire.h | 4 + Yao/YaoWire.hpp | 20 ++ doc/Doxyfile | 2 +- doc/conf.py | 5 +- doc/index.rst | 14 +- doc/io.rst | 10 + doc/low-level.rst | 142 +++++--------- doc/networking.rst | 6 +- doc/non-linear.rst | 4 +- doc/preprocessing.rst | 64 ++++++- doc/troubleshooting.rst | 21 ++- 216 files changed, 2406 insertions(+), 1113 deletions(-) create mode 100644 Machines/Atlas.hpp create mode 100644 Machines/Rep4.hpp create mode 100644 Machines/Semi2k.hpp create mode 100644 Machines/Tinier.cpp create mode 100644 Processor/NoFilePrep.h create mode 100644 Processor/OnlineOptions.hpp create mode 100644 Programs/Source/keras_mnist_lenet_predict.mpc create mode 100644 Protocols/ProtocolSet.h create mode 100644 Protocols/ProtocolSetup.h create mode 100644 Protocols/ReplicatedPO.h create mode 100644 Protocols/ReplicatedPO.hpp rename Protocols/{Semi2k.h => Semi.h} (75%) create mode 100755 Scripts/decompile.py create mode 100755 Scripts/memory-usage.py create mode 100644 Tools/TimerWithComm.cpp create mode 100644 Tools/TimerWithComm.h create mode 100644 Tools/benchmarking.cpp create mode 100644 Utils/binary-example.cpp create mode 100644 Utils/mixed-example.cpp diff --git a/.gitmodules b/.gitmodules index 455a55143..32dca28be 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,7 @@ url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] path = mpir - url = git://github.com/wbhart/mpir.git + url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 5ca1360ab..84ba909b3 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -259,7 +259,7 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl; + cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl; delete P; } delete[] eval_threads; diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 55adcbfb1..760a20b89 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -175,7 +175,7 @@ void GarbleInputter::exchange() assert(party.P != 0); assert(party.MC != 0); auto& protocol = party.shared_proc->protocol; - protocol.init_mul(party.shared_proc); + protocol.init_mul(); for (auto& tuple : tuples) protocol.prepare_mul(tuple.first->mask, T::constant(1, party.P->my_num(), party.mac_key) diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 0c97f9bd8..8e16c3077 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -155,7 +155,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : while (next != GC::DONE_BREAK); MC->Check(*P); - data_sent = P->comm_stats.total_data() + prep->data_sent(); + data_sent = P->total_comm().sent; this->machine.write_memory(this->N.my_num()); } @@ -173,7 +173,8 @@ void RealProgramParty::garble() garble_jobs.clear(); garble_inputter->reset_all(*P); auto& protocol = *garble_protocol; - protocol.init_mul(shared_proc); + protocol.init(*prep, shared_proc->MC); + protocol.init_mul(); next = this->first_phase(program, garble_processor, this->garble_machine); @@ -181,7 +182,8 @@ void RealProgramParty::garble() protocol.exchange(); typename T::Protocol second_protocol(*P); - second_protocol.init_mul(shared_proc); + second_protocol.init(*prep, shared_proc->MC); + second_protocol.init_mul(); for (auto& job : garble_jobs) job.middle_round(*this, second_protocol); diff --git a/BMR/Register.h b/BMR/Register.h index d0a75e930..f348f7b7e 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -293,6 +293,9 @@ class ProgramRegister : public Phase, public Register template static void convcbit2s(GC::Processor&, const BaseInstruction&) { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } // most BMR phases don't need actual input template diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 6bd1ba264..439bcfc73 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty() _received_gc_received = 0; n_received = 0; randomfd = open("/dev/urandom", O_RDONLY); + done_filling = false; +} + +BaseTrustedParty::~BaseTrustedParty() +{ + close(randomfd); } TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index 24e8120de..260e7a516 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -20,7 +20,7 @@ class BaseTrustedParty : virtual public CommonFakeParty { vector msg_input_masks; BaseTrustedParty(); - virtual ~BaseTrustedParty() {} + virtual ~BaseTrustedParty(); /* From NodeUpdatable class */ virtual void NodeReady(); @@ -104,7 +104,6 @@ class TrustedProgramParty : public BaseTrustedParty { void add_all_keys(const Register& reg, bool external); }; - inline void BaseTrustedParty::add_keys(const Register& reg) { for(int p = 0; p < get_n_parties(); p++) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c9be9e5b..2b75d24f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.9 (Jan 11, 2021) + +- Disassembler +- Run-time parameter for probabilistic truncation error +- Probabilistic truncation for some protocols computing modulo a prime +- Simplified C++ interface +- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757) +- More general scalar-vector multiplication +- Complete memory support for clear bits +- Extended clear bit functionality with Yao's garbled circuits +- Allow preprocessing information to be supplied via named pipes +- In-place operations for containers + ## 0.2.8 (Nov 4, 2021) - Tested on Apple laptop with ARM chip diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5a65e73ac..53da15ba2 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -112,10 +112,16 @@ def load_mem(cls, address, mem_type=None, size=None): return cls.load_dynamic_mem(address) else: for i in range(res.size): - cls.load_inst[util.is_constant(address)](res[i], address + i) + cls.mem_op(cls.load_inst, res[i], address + i) return res def store_in_mem(self, address): - self.store_inst[isinstance(address, int)](self, address) + self.mem_op(self.store_inst, self, address) + @staticmethod + def mem_op(inst, reg, address): + direct = isinstance(address, int) + if not direct: + address = regint.conv(address) + inst[direct](reg, address) @classmethod def new(cls, value=None, n=None): if util.is_constant(value): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index f4cf89ad6..2f7ca81f5 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -77,13 +77,16 @@ def LTZ(s, a, k, kappa): k: bit length of a """ + movs(s, program.non_linear.ltz(a, k, kappa)) + +def LtzRing(a, k): from .types import sint, _bitint from .GC.types import sbitvec if program.use_split(): summands = a.split_to_two_summands(k) carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] - movs(s, sint.conv(msb)) + return sint.conv(msb) return elif program.options.ring: from . import floatingpoint @@ -96,11 +99,7 @@ def LTZ(s, a, k, kappa): a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) - movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))) - return - t = sint() - Trunc(t, a, k, k - 1, kappa, True) - subsfi(s, t, 0) + return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)) def LessThanZero(a, k, kappa): from . import types diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 64e76434c..b2898e21a 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -82,7 +82,7 @@ def run(args, options): prog.finalize() if prog.req_num: - print('Program requires:') + print('Program requires at most:') for x in prog.req_num.pretty(): print(x) diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py index fd0265637..c68ecd317 100644 --- a/Compiler/exceptions.py +++ b/Compiler/exceptions.py @@ -12,4 +12,7 @@ class ArgumentError(CompilerError): """ Exception raised for errors in instruction argument parsing. """ def __init__(self, arg, msg): self.arg = arg - self.msg = msg \ No newline at end of file + self.msg = msg + +class VectorMismatch(CompilerError): + pass diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index a15a62dd9..c596240b6 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -392,7 +392,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) - lts(d, c_dprime, r_prime, l, kappa) + d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa) if compute_modulo: b = c_dprime - r_prime + pow2m * d return b, pow2m @@ -629,12 +629,14 @@ def BITLT(a, b, bit_length): # - From the paper # Multiparty Computation for Interval, Equality, and Comparison without # Bit-Decomposition Protocol -def BitDecFull(a, maybe_mixed=False): +def BitDecFull(a, n_bits=None, maybe_mixed=False): from .library import get_program, do_while, if_, break_point from .types import sint, regint, longint, cint p = get_program().prime assert p bit_length = p.bit_length() + n_bits = n_bits or bit_length + assert n_bits <= bit_length logp = int(round(math.log(p, 2))) if abs(p - 2 ** logp) / p < 2 ** -get_program().security: # inspired by Rabbit (https://eprint.iacr.org/2021/119) @@ -677,12 +679,12 @@ def _(): czero = (c==0) q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t) fbar = [bbits[0].clear_type.conv(cint(x)) - for x in ((1< 0. + """ Crash runtime if the value in the register is not zero. :param: Crash condition (regint)""" code = base.opcodes['CRASH'] @@ -1275,7 +1275,7 @@ class prep(base.Instruction): field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, self.args[0]), 1) + req_node.increment((self.field_type, self.args[0]), self.get_size()) def has_var_args(self): return True @@ -2407,19 +2407,6 @@ def expand(self): subml(self.args[0], s[5], c[1]) -@base.gf2n -@base.vectorize -class lts(base.CISC): - """ Secret comparison $s_i = (s_j < s_k)$. """ - __slots__ = [] - arg_format = ['sw', 's', 's', 'int', 'int'] - - def expand(self): - from .types import sint - a = sint() - subs(a, self.args[1], self.args[2]) - comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) - # placeholder for documentation class cisc: """ Meta instruction for emulation. This instruction is only generated diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 38fd97d29..fb2a67b89 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -4,6 +4,8 @@ import inspect import functools import copy +import sys +import struct from Compiler.exceptions import * from Compiler.config import * from Compiler import util @@ -299,11 +301,12 @@ def maybe_vectorized_instruction(*args, **kwargs): vectorized_name = 'v' + instruction.__name__ Vectorized_Instruction.__name__ = vectorized_name global_dict[vectorized_name] = Vectorized_Instruction + + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction - instruction.__doc__ = '' - # exclude GF(2^n) instructions from documentation - if instruction.code and instruction.code >> 8 == 1: - maybe_vectorized_instruction.__doc__ = '' + maybe_vectorized_instruction.arg_format = instruction.arg_format return maybe_vectorized_instruction @@ -389,8 +392,11 @@ def maybe_gf2n_instruction(*args, **kwargs): else: global_dict[GF2N_Instruction.__name__] = GF2N_Instruction + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction_cls - instruction_cls.__doc__ = '' + maybe_gf2n_instruction.arg_format = instruction.arg_format return maybe_gf2n_instruction #return instruction @@ -661,6 +667,12 @@ def encode(cls, arg): assert arg.i >= 0 return int_to_bytes(arg.i) + def __init__(self, f): + self.i = struct.unpack('>I', f.read(4))[0] + + def __str__(self): + return self.reg_type + str(self.i) + class ClearModpAF(RegisterArgFormat): reg_type = RegType.ClearModp @@ -686,6 +698,12 @@ def check(cls, arg): def encode(cls, arg): return int_to_bytes(arg) + def __init__(self, f): + self.i = struct.unpack('>i', f.read(4))[0] + + def __str__(self): + return str(self.i) + class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): @@ -722,6 +740,13 @@ def check(cls, arg): def encode(cls, arg): return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg)) + def __init__(self, f): + tmp = f.read(16) + self.str = str(tmp[0:tmp.find(b'\0')], 'ascii') + + def __str__(self): + return self.str + ArgFormats = { 'c': ClearModpAF, 's': SecretModpAF, @@ -890,6 +915,54 @@ def __str__(self): def __repr__(self): return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' +class ParsedInstruction: + reverse_opcodes = {} + + def __init__(self, f): + cls = type(self) + from Compiler import instructions + from Compiler.GC import instructions as gc_inst + if not cls.reverse_opcodes: + for module in instructions, gc_inst: + for x, y in inspect.getmodule(module).__dict__.items(): + if inspect.isclass(y) and y.__name__[0] != 'v': + try: + cls.reverse_opcodes[y.code] = y + except AttributeError: + pass + read = lambda: struct.unpack('>I', f.read(4))[0] + full_code = read() + code = full_code % (1 << Instruction.code_length) + self.size = full_code >> Instruction.code_length + self.type = cls.reverse_opcodes[code] + t = self.type + name = t.__name__ + try: + n_args = len(t.arg_format) + self.var_args = False + except: + n_args = read() + self.var_args = True + try: + arg_format = iter(t.arg_format) + except: + if name == 'cisc': + arg_format = itertools.chain(['str'], itertools.repeat('int')) + else: + arg_format = itertools.repeat('int') + self.args = [ArgFormats[next(arg_format)](f) + for i in range(n_args)] + + def __str__(self): + name = self.type.__name__ + res = name + ' ' + if self.size > 1: + res = 'v' + res + str(self.size) + ', ' + if self.var_args: + res += str(len(self.args)) + ', ' + res += ', '.join(str(arg) for arg in self.args) + return res + class VarArgsInstruction(Instruction): def has_var_args(self): return True diff --git a/Compiler/library.py b/Compiler/library.py index 529608dc2..7bab1951a 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -219,6 +219,9 @@ def crash(condition=None): :param condition: crash if true (default: true) """ + if isinstance(condition, localint): + # allow crash on local values + condition = condition._v if condition == None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -1347,6 +1350,8 @@ def while_loop(loop_body, condition, arg, g=None): arg = regint(arg) def loop_fn(): result = loop_body(arg) + if isinstance(result, MemValue): + result = result.read() result.link(arg) cont = condition(result) return cont @@ -1531,6 +1536,8 @@ def decorator(body): def if_e(condition): """ Conditional execution with else block. + Use :py:class:`~Compiler.types.MemValue` to assign values that + live beyond. :param condition: regint/cint/int @@ -1538,12 +1545,13 @@ def if_e(condition): .. code:: + y = MemValue(0) @if_e(x > 0) def _(): - ... + y.write(1) @else_ def _(): - ... + y.write(0) """ try: condition = bool(condition) @@ -1647,11 +1655,18 @@ def get_player_id(): return res def listen_for_clients(port): - """ Listen for clients on specific port. """ + """ Listen for clients on specific port base. + + :param port: port base (int/regint/cint) + """ instructions.listen(regint.conv(port)) def accept_client_connection(port): - """ Listen for clients on specific port. """ + """ Accept client connection on specific port base. + + :param port: port base (int/regint/cint) + :returns: client id + """ res = regint() instructions.acceptclientconnection(res, regint.conv(port)) return res diff --git a/Compiler/ml.py b/Compiler/ml.py index 5ff1a3753..7e53a78f8 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1810,6 +1810,7 @@ def __init__(self, report_loss=None): self.print_loss_reduction = False self.i_epoch = MemValue(0) self.stopped_on_loss = MemValue(0) + self.stopped_on_low_loss = MemValue(0) @property def layers(self): @@ -1932,6 +1933,7 @@ def run(self, batch_size=None, stop_on_loss=0): """ Run training. :param batch_size: batch size (defaults to example size of first layer) + :param stop_on_loss: stop when loss falls below this (default: 0) """ if self.n_epochs == 0: return @@ -2013,6 +2015,7 @@ def _(j): if self.tol > 0: res *= (1 - (loss_sum >= 0) * \ (loss_sum < self.tol * n_per_epoch)).reveal() + self.stopped_on_low_loss.write(1 - res) return res def reveal_correctness(self, data, truth, batch_size): @@ -2138,6 +2141,7 @@ def _(): if depreciation: self.gamma.imul(depreciation) print_ln('reducing learning rate to %s', self.gamma) + return 1 - self.stopped_on_low_loss if 'model_output' in program.args: self.output_weights() @@ -2386,6 +2390,7 @@ def trainable_variables(self): return list(self.opt.thetas) def build(self, input_shape, batch_size=128): + data_input_shape = input_shape if self.opt != None and \ input_shape == self.opt.layers[0].X.sizes and \ batch_size <= self.batch_size and \ @@ -2458,9 +2463,10 @@ def build(self, input_shape, batch_size=128): else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: - layers.append(Output(input_shape[0])) + layers.append(Output(data_input_shape[0])) else: - layers.append(MultiOutput(input_shape[0], layers[-1].d_out)) + layers.append( + MultiOutput(data_input_shape[0], layers[-1].d_out)) if self.optimizer[1]: raise Exception('use keyword arguments for optimizer') opt = self.optimizer[0] @@ -2504,7 +2510,7 @@ def fit(self, x, y, batch_size, epochs=1, validation_data=None): if x.total_size() != self.opt.layers[0].X.total_size(): raise Exception('sample data size mismatch') if y.total_size() != self.opt.layers[-1].Y.total_size(): - print (y, layers[-1].Y) + print (y, self.opt.layers[-1].Y) raise Exception('label size mismatch') if validation_data == None: validation_data = None, None diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 43e10c2e6..01cb4db58 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -1,7 +1,7 @@ from .comparison import * from .floatingpoint import * from .types import * -from . import comparison +from . import comparison, program class NonLinear: kappa = None @@ -30,6 +30,15 @@ def mod2m(self, a, k, m, signed): def trunc_pr(self, a, k, m, signed=True): if isinstance(a, types.cint): return shift_two(a, m) + prog = program.Program.prog + if prog.use_trunc_pr: + if signed and prog.use_trunc_pr != -1: + a += (1 << (k - 1)) + res = sint() + trunc_pr(res, a, k, m) + if signed and prog.use_trunc_pr != -1: + res -= (1 << (k - m - 1)) + return res return self._trunc_pr(a, k, m, signed) def trunc_round_nearest(self, a, k, m, signed): @@ -44,6 +53,9 @@ def trunc(self, a, k, m, kappa, signed): return a return self._trunc(a, k, m, signed) + def ltz(self, a, k, kappa=None): + return -self.trunc(a, k, k - 1, kappa, True) + class Masking(NonLinear): def eqz(self, a, k): c, r = self._mask(a, k) @@ -100,42 +112,44 @@ def __init__(self, prime): def _mod2m(self, a, k, m, signed): if signed: a += cint(1) << (k - 1) - return sint.bit_compose(self.bit_dec(a, k, k, True)[:m]) + return sint.bit_compose(self.bit_dec(a, k, m, True)) def _trunc_pr(self, a, k, m, signed): # nearest truncation return self.trunc_round_nearest(a, k, m, signed) def _trunc(self, a, k, m, signed=None): - if signed: - a += cint(1) << (k - 1) - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) - if signed: - res -= cint(1) << (k - 1 - m) - return res + return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed) def trunc_round_nearest(self, a, k, m, signed): a += cint(1) << (m - 1) if signed: a += cint(1) << (k - 1) k += 1 - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) + res = self._trunc(a, k, m, False) if signed: res -= cint(1) << (k - m - 2) return res def bit_dec(self, a, k, m, maybe_mixed=False): assert k < self.prime.bit_length() - bits = BitDecFull(a, maybe_mixed=maybe_mixed) - if len(bits) < m: - raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) - return bits[:m] + bits = BitDecFull(a, m, maybe_mixed=maybe_mixed) + assert len(bits) == m + return bits def eqz(self, a, k): # always signed a += two_power(k) return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) + def ltz(self, a, k, kappa=None): + if k + 1 < self.prime.bit_length(): + # https://dl.acm.org/doi/10.1145/3474123.3486757 + # "negative" values wrap around when doubling, thus becoming odd + return self.mod2m(2 * a, k + 1, 1, False) + else: + return super(KnownPrime, self).ltz(a, k, kappa) + class Ring(Masking): """ Non-linear functionality modulo a power of two known at compile time. """ @@ -172,3 +186,6 @@ def trunc_round_nearest(self, a, k, m, signed): return TruncRing(None, tmp + 1, k - m + 1, 1, signed) else: return super(Ring, self).trunc_round_nearest(a, k, m, signed) + + def ltz(self, a, k, kappa=None): + return LtzRing(a, k) diff --git a/Compiler/program.py b/Compiler/program.py index 19ce52480..5dad8e516 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -578,6 +578,15 @@ def disable_memory_warnings(self): self.warn_about_mem.append(False) self.curr_block.warn_about_mem = False + @staticmethod + def read_tapes(schedule): + if not os.path.exists(schedule): + schedule = 'Programs/Schedules/%s.sch' % schedule + + lines = open(schedule).readlines() + for tapename in lines[2].split(' '): + yield tapename.strip() + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -1109,7 +1118,20 @@ def require_bit_length(self, bit_length, t='p'): else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) - class Register(object): + @staticmethod + def read_instructions(tapename): + tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb') + while tape.peek(): + yield inst_base.ParsedInstruction(tape) + + class _no_truth(object): + __slots__ = [] + + def __bool__(self): + raise CompilerError('Cannot derive truth value from register, ' + "consider using 'compile.py -l'") + + class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. @@ -1233,10 +1255,6 @@ def is_clear(self): self.reg_type == RegType.ClearGF2N or \ self.reg_type == RegType.ClearInt - def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") - def __str__(self): return self.reg_type + str(self.i) diff --git a/Compiler/types.py b/Compiler/types.py index 48bb27a1d..33df2e373 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,7 +127,7 @@ def vectorized_operation(self, *args, **kwargs): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise CompilerError('Different vector sizes of operands: %d/%d' + raise VectorMismatch('Different vector sizes of operands: %d/%d' % (self.size, args[0].size)) set_global_vector_size(self.size) try: @@ -221,7 +221,7 @@ def inputmixed(*args): else: instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),))) -class _number(object): +class _number(Tape._no_truth): """ Number functionality. """ def square(self): @@ -246,7 +246,11 @@ def __mul__(self, other): elif is_one(other): return self else: - return self.mul(other) + try: + return self.mul(other) + except VectorMismatch: + # try reverse multiplication + return NotImplemented __radd__ = __add__ __rmul__ = __mul__ @@ -320,7 +324,7 @@ def __abs__(self): def popcnt_bits(bits): return sum(bits) -class _int(object): +class _int(Tape._no_truth): """ Integer functionality. """ @staticmethod @@ -408,7 +412,7 @@ def half_adder(self, other): def long_one(): return 1 -class _bit(object): +class _bit(Tape._no_truth): """ Binary functionality. """ def bit_xor(self, other): @@ -474,7 +478,7 @@ def bit_xor(self, other): def bit_not(self): return self ^ 1 -class _structure(object): +class _structure(Tape._no_truth): """ Interface for type-dependent container types. """ MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value))) @@ -591,7 +595,7 @@ def traverse(content, level): res.input_from(player) return res -class _vec(object): +class _vec(Tape._no_truth): def link(self, other): assert len(self.v) == len(other.v) for x, y in zip(self.v, other.v): @@ -726,7 +730,7 @@ def expand_to_vector(self, size=None): assert self.size == 1 res = type(self)(size=size) for i in range(size): - movs(res[i], self) + self.mov(res[i], self) return res class _clear(_register): @@ -1010,9 +1014,10 @@ def less_than(self, other, bit_length): if bit_length <= 64: return regint(self) < regint(other) else: + sint.require_bit_length(bit_length + 1) diff = self - other - diff += (1 << (bit_length - 1)) - shifted = diff >> (bit_length - 1) + diff += 1 << bit_length + shifted = diff >> bit_length res = 1 - regint(shifted & 1) return res @@ -1646,7 +1651,7 @@ def binary_output(self, player=None): player = -1 intoutput(player, self) -class localint(object): +class localint(Tape._no_truth): """ Local integer that must prevented from leaking into the secure computation. Uses regint internally. @@ -1669,7 +1674,7 @@ def output(self): __eq__ = lambda self, other: localint(self._v == other) __ne__ = lambda self, other: localint(self._v != other) -class personal(object): +class personal(Tape._no_truth): def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -2003,9 +2008,11 @@ def mul(self, other): size or one size 1 for a value-vector multiplication. :param other: any compatible type """ - if isinstance(other, _secret) and (1 in (self.size, other.size)) \ + if isinstance(other, _register) and (1 in (self.size, other.size)) \ and (self.size, other.size) != (1, 1): x, y = (other, self) if self.size < other.size else (self, other) + if not isinstance(other, _secret): + return y.expand_to_vector(x.size) * x res = type(self)(size=x.size) mulrs(res, x, y) return res @@ -2221,11 +2228,13 @@ def get_raw_input_from(cls, player): @vectorized_classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of values input by a client. + This uses the triple-based input protocol introduced by + `DamgÃ¥rd et al. `_ :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) - + :returns: list of sint """ # send shares of a triple to client triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) @@ -2910,7 +2919,7 @@ def bit_decompose_embedding(self): sint.bit_type = sintbit sgf2n.bit_type = sgf2n -class _bitint(object): +class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False @@ -3521,6 +3530,7 @@ def from_int(cls, other): @classmethod def _new(cls, other, k=None, f=None): + assert not isinstance(other, (list, tuple)) res = cls(k=k, f=f) res.v = cint.conv(other) return res @@ -3567,6 +3577,8 @@ def __len__(self): return len(self.v) def __getitem__(self, index): + if isinstance(index, slice): + return [self._new(x, k=self.k, f=self.f) for x in self.v[index]] return self._new(self.v[index], k=self.k, f=self.f) @vectorize @@ -3608,7 +3620,6 @@ def add(self, other): else: return NotImplemented - @vectorize def mul(self, other): """ Clear fixed-point multiplication. @@ -4045,7 +4056,8 @@ def set_precision_from_args(cls, program, adapt_ring=False): 'for fixed-point computation') cls.round_nearest = True if adapt_ring and program.options.ring \ - and 'fix_ring' not in program.args: + and 'fix_ring' not in program.args \ + and 2 * cls.k > int(program.options.ring): need = 2 ** int(math.ceil(math.log(2 * cls.k, 2))) if need != int(program.options.ring): print('Changing computation modulus to 2^%d' % need) @@ -4489,7 +4501,7 @@ def for_mux(self, other): def __neg__(self): return self._new(-self.v + 2 * util.expand(self.Z, self.v.size)) -class _unreduced_squant(object): +class _unreduced_squant(Tape._no_truth): def __init__(self, v, params, res_params=None, n_summands=1): self.v = v self.params = params @@ -5011,7 +5023,7 @@ def reveal(self): :return: cfloat """ return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal()) -class cfloat(object): +class cfloat(Tape._no_truth): """ Helper class for printing revealed sfloats. """ __slots__ = ['v', 'p', 'z', 's', 'nan'] diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index a68f8e833..fc19e989b 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -52,10 +52,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); pShare sk = typename T::Honest::Protocol(P).get_random(); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; DataPositions usage; diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index 87573593b..920397cef 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -5,6 +5,8 @@ #define NO_MIXED_CIRCUITS +#define NO_SECURITY_CHECK + #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 58f35d4b0..569aa791f 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -113,10 +113,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); sk_prep.get_two(DATA_INVERSE, sk, __); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; typename pShare::TriplePrep prep(0, usage); diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 334d5d1ba..0a5e0ab9c 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -41,8 +41,8 @@ void preprocessing(vector>& tuples, int buffer_size, timer.start(); Player& P = proc.P; auto& prep = proc.DataF; - size_t start = P.sent + prep.data_sent(); - auto stats = P.comm_stats + prep.comm_stats(); + size_t start = P.total_comm().sent; + auto stats = P.total_comm(); auto& extra_player = P; auto& protocol = proc.protocol; @@ -77,7 +77,7 @@ void preprocessing(vector>& tuples, int buffer_size, MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); protocol.start_exchange(); @@ -106,9 +106,9 @@ void preprocessing(vector>& tuples, int buffer_size, timer.stop(); cout << "Generated " << buffer_size << " tuples in " << timer.elapsed() << " seconds, throughput " << buffer_size / timer.elapsed() << ", " - << 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size + << 1e-3 * (P.total_comm().sent - start) / buffer_size << " kbytes per tuple" << endl; - (P.comm_stats + prep.comm_stats() - stats).print(true); + (P.total_comm() - stats).print(true); } template class T> diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index 10991276a..5686349e2 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length, (void) pk; Timer timer; timer.start(); - size_t start = P.sent; - auto stats = P.comm_stats; + auto stats = P.total_comm(); EcSignature signature; vector opened_R; if (opts.R_after_msg) @@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length, auto& protocol = proc->protocol; if (proc) { - protocol.init_mul(proc); + protocol.init_mul(); protocol.prepare_mul(sk, tuple.a); protocol.start_exchange(); } @@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length, auto rx = tuple.R.x(); signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); + auto diff = (P.total_comm() - stats); cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending " - << (P.sent - start) << " bytes" << endl; - auto diff = (P.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(true); return signature; } @@ -139,11 +138,11 @@ void sign_benchmark(vector>& tuples, T sk, P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); P256Element pk = MCc.open(sk, P); MCc.Check(P); cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); for (size_t i = 0; i < min(10lu, tuples.size()); i++) { @@ -154,13 +153,12 @@ void sign_benchmark(vector>& tuples, T sk, Timer timer; timer.start(); auto& check_player = MCp.get_check_player(P); - auto stats = check_player.comm_stats; - auto start = check_player.sent; + auto stats = check_player.total_comm(); MCp.Check(P); MCc.Check(P); + auto diff = (check_player.total_comm() - stats); cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending " - << (check_player.sent - start) << " bytes" << endl; - auto diff = (check_player.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(); } } diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 12ba1c938..5f8e76fd3 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -8,6 +8,9 @@ #include "Networking/ssl_sockets.h" +/** + * Client-side interface + */ class Client { vector plain_sockets; @@ -15,15 +18,37 @@ class Client ssl_service io_service; public: + /** + * Sockets for cleartext communication + */ vector sockets; + + /** + * Specification of computation domain + */ octetStream specification; + /** + * Start a new set of connections to computing parties. + * @param hostnames location of computing parties + * @param port_base port base + * @param my_client_id client identifier + */ Client(const vector& hostnames, int port_base, int my_client_id); ~Client(); + /** + * Securely input private values. + * @param values vector of integer-like values + */ template void send_private_inputs(const vector& values); + /** + * Securely receive output values. + * @param n number of values + * @returns vector of integer-like values + */ template vector receive_outputs(int n); }; diff --git a/ExternalIO/README.md b/ExternalIO/README.md index 36649f5c7..d4f99288b 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -19,6 +19,8 @@ Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 ``` +`` can be any arithmetic protocol (e.g., `mascot`) but not a +binary protocol (e.g., `yao`). This should output that the winning id is 1. Note that the ids have to be incremental, and the client with the highest id has to input 1 as the last argument while the others have to input 0 there. Furthermore, @@ -32,54 +34,21 @@ different hosts, you will have to distribute the `*.pem` files. ### Connection Setup -**listen**(*int port_num*) - -Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background. - -*port_num* - the port number to listen on. - -**acceptclientconnection**(*regint client_socket_id*, *int port_num*) - -Picks the first available client socket connection. Blocks if none available. - -*client_socket_id* - an identifier used to refer to the client socket. - -*port_num* - the port number identifies the socket server to accept connections on. +1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients) +2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection) +3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection) ### Data Exchange -Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py). - -*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*) - -Read a share of an input from a client, blocking on the client send. - -*client_socket_id* - an identifier used to refer to the client socket. - -*number_of_inputs* - the number of inputs expected - -*[inputs]* - returned list of shares of private input. - -**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*) - -Write shares of values including macs to an external client. - -*client_socket_id* - an identifier used to refer to the client socket. - -*[values]* - list of shares of values to send to client. - -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. - -See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message. - -*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*) - -Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf) - -*number_of_inputs* - the number of inputs expected +Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types). -*client_socket_id* - an identifier used to refer to the client socket. +1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket) +2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client) +3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients) -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. +## Client-Side Interface -*[inputs]* - returned list of shares of private input. +The example uses the `Client` class implemented in +`ExternalIO/Client.hpp` to handle the communication, see +https://mp-spdz.readthedocs.io/en/latest/io.html#reference for +documentation. diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index ac56668a2..8ac400839 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -33,9 +33,6 @@ class FHE_Params int n_mults() const { return FFTData.size() - 1; } - // Rely on default copy assignment/constructor (not that they should - // ever be needed) - void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index cb5daa386..c6e294a63 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -178,12 +178,6 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, return extra_slack; } - - - -/****************************************************************************** - * Here onwards needs NTL - ******************************************************************************/ @@ -345,6 +339,7 @@ ZZX Cyclotomic(int N) return F; } #else +// simplified version powers of two int phi_N(int N) { if (((N - 1) & N) != 0) diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index ab150d272..c0a2ecfea 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -1,8 +1,6 @@ #ifndef _NTL_Subs #define _NTL_Subs -/* All these routines use NTL on the inside */ - #include "FHE/Ring.h" #include "FHE/FFT_Data.h" #include "FHE/P2Data.h" @@ -47,7 +45,7 @@ class Parameters }; -// Main setup routine (need NTL if online_only is false) +// Main setup routine void generate_setup(int nparties, int lgp, int lg2, int sec, bool skip_2 = false, int slack = 0, bool round_up = false); @@ -60,7 +58,6 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, bool round_up); -// Everything else needs NTL void init(Ring& Rg, int m, bool generate_poly); void init(P2Data& P2D,const Ring& Rg); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index ae52fc62f..7ab8e5172 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -114,7 +114,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, cout << "n: " << n << endl; cout << "sec: " << sec << endl; cout << "sigma: " << this->sigma << endl; - cout << "h: " << h << endl; cout << "B_clean size: " << numBits(B_clean) << endl; cout << "B_scale size: " << numBits(B_scale) << endl; cout << "B_KS size: " << numBits(B_KS) << endl; diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 9c2545ed8..812560a3a 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -401,19 +401,29 @@ void Ring_Element::change_rep(RepType r) bool Ring_Element::equals(const Ring_Element& a) const { - if (element.empty() and a.element.empty()) - return true; - else if (element.empty() or a.element.empty()) - throw not_implemented(); - if (rep!=a.rep) { throw rep_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } + + if (is_zero() or a.is_zero()) + return is_zero() and a.is_zero(); + for (int i=0; i<(*FFTD).phi_m(); i++) { if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } } return true; } +bool Ring_Element::is_zero() const +{ + if (element.empty()) + return true; + for (auto& x : element) + if (not ::isZero(x, FFTD->get_prD())) + return false; + return true; +} + + ConversionIterator Ring_Element::get_iterator() const { if (rep != polynomial) @@ -560,6 +570,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const { if (&FFTD != this->FFTD) throw params_mismatch(); + if (is_zero()) + throw runtime_error("element is zero"); } diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index 5cc93ca9a..5982bbe32 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -95,6 +95,7 @@ class Ring_Element void randomize(PRNG& G,bool Diag=false); bool equals(const Ring_Element& a) const; + bool is_zero() const; // This is a NOP in cases where we cannot do a FFT void change_rep(RepType r); diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp index ed5fb303e..dcbd29b52 100644 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -175,7 +175,7 @@ size_t PairwiseGenerator::report_size(ReportType type) template size_t PairwiseGenerator::report_sent() { - return P.sent; + return P.total_comm().sent; } template diff --git a/FHEOffline/SimpleGenerator.h b/FHEOffline/SimpleGenerator.h index 9cacad697..d5ee933af 100644 --- a/FHEOffline/SimpleGenerator.h +++ b/FHEOffline/SimpleGenerator.h @@ -71,7 +71,7 @@ class SimpleGenerator : public GeneratorBase void run(bool exhaust); size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); - size_t report_sent() { return P.sent; } + size_t report_sent() { return P.total_comm().sent; } }; #endif /* FHEOFFLINE_SIMPLEGENERATOR_H_ */ diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index 9f8525971..437af179a 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -96,7 +96,7 @@ void BitAdder::add(vector >& res, b[j] = summands[i][1][input_begin + j]; } - protocol.init_mul(&proc); + protocol.init_mul(); for (size_t j = 0; j < n_items; j++) { res[begin + j][i] = a[j] + b[j] + carries[j]; diff --git a/GC/CcdPrep.h b/GC/CcdPrep.h index 8d232444c..ab02ea802 100644 --- a/GC/CcdPrep.h +++ b/GC/CcdPrep.h @@ -91,11 +91,6 @@ class CcdPrep : public BufferPrep (typename T::clear(tmp.get_bit(0)) << i); } } - - NamedCommStats comm_stats() - { - return part_prep.comm_stats(); - } }; } /* namespace GC */ diff --git a/GC/CcdPrep.hpp b/GC/CcdPrep.hpp index f9535350b..3124efc42 100644 --- a/GC/CcdPrep.hpp +++ b/GC/CcdPrep.hpp @@ -25,6 +25,14 @@ void CcdPrep::set_protocol(typename T::Protocol& protocol) { auto& thread = ShareThread::s(); assert(thread.MC); + + if (part_proc) + { + assert(&part_proc->MC == &thread.MC->get_part_MC()); + assert(&part_proc->P == &protocol.get_part().P); + return; + } + part_proc = new SubProcessor( thread.MC->get_part_MC(), part_prep, protocol.get_part().P); } diff --git a/GC/CcdShare.h b/GC/CcdShare.h index aececad0a..e890ce633 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -27,6 +27,7 @@ class CcdShare : public ShamirShare, public ShareSecret> typedef ShamirInput Input; typedef ShamirMC MAC_Check; + typedef Shamir Protocol; typedef This small_type; diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 55c537de3..00e6c52c9 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -108,6 +108,9 @@ class FakeSecret : public ShareInterface, public BitVec template static void convcbit2s(GC::Processor&, const BaseInstruction&) { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } static FakeSecret input(GC::Processor& processor, const InputArgs& args); static FakeSecret input(int from, word input, int n_bits); diff --git a/GC/Instruction.cpp b/GC/Instruction.cpp index 3fe0cc588..6be1eb1ab 100644 --- a/GC/Instruction.cpp +++ b/GC/Instruction.cpp @@ -84,7 +84,7 @@ void Instruction::parse(istream& s, int pos) ostringstream os; os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl; os << "This virtual machine executes binary circuits only." << endl; - os << "Try compiling with '-B' or use only sbit* types." << endl; + os << "Use 'compile.py -B'." << endl; throw Invalid_Instruction(os.str()); break; } diff --git a/GC/NoShare.h b/GC/NoShare.h index f60eccd75..c435ec3f8 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -7,6 +7,7 @@ #define GC_NOSHARE_H_ #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Protocols/ShareInterface.h" class InputArgs; @@ -148,11 +149,14 @@ class NoShare : public ShareInterface static void trans(Processor&, Integer, const vector&) { fail(); } + static void andm(GC::Processor&, const BaseInstruction&) { fail(); } + static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; } NoShare() {} - NoShare(int) { fail(); } + template + NoShare(T) { fail(); } void load_clear(Integer, Integer) { fail(); } void random_bit() { fail(); } diff --git a/GC/PostSacriBin.cpp b/GC/PostSacriBin.cpp index 81341cf08..742480600 100644 --- a/GC/PostSacriBin.cpp +++ b/GC/PostSacriBin.cpp @@ -9,6 +9,7 @@ #include "Protocols/Replicated.hpp" #include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "ShareSecret.hpp" namespace GC @@ -28,24 +29,19 @@ PostSacriBin::~PostSacriBin() } } -void PostSacriBin::init_mul(SubProcessor* proc) -{ - assert(proc != 0); - init_mul(proc->DataF, proc->MC); -} - -void PostSacriBin::init_mul(Preprocessing&, T::MC&) +void PostSacriBin::init_mul() { if ((int) inputs.size() >= OnlineOptions::singleton.batch_size) check(); honest.init_mul(); } -PostSacriBin::T::clear PostSacriBin::prepare_mul(const T& x, const T& y, int n) +void PostSacriBin::prepare_mul(const T& x, const T& y, int n) { + if (n == -1) + n = T::default_length; honest.prepare_mul(x, y, n); inputs.push_back({{x.mask(n), y.mask(n)}}); - return {}; } void PostSacriBin::exchange() @@ -55,6 +51,8 @@ void PostSacriBin::exchange() PostSacriBin::T PostSacriBin::finalize_mul(int n) { + if (n == -1) + n = T::default_length; auto res = honest.finalize_mul(n); outputs.push_back({res, n}); return res; diff --git a/GC/PostSacriBin.h b/GC/PostSacriBin.h index 50baa9c5d..8f1643a76 100644 --- a/GC/PostSacriBin.h +++ b/GC/PostSacriBin.h @@ -38,9 +38,8 @@ class PostSacriBin : public ReplicatedBase, PostSacriBin(Player& P); ~PostSacriBin(); - void init_mul(Preprocessing&, T::MC&); - void init_mul(SubProcessor* proc); - T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); diff --git a/GC/RepPrep.hpp b/GC/RepPrep.hpp index 1c91fd395..f83fbdaf4 100644 --- a/GC/RepPrep.hpp +++ b/GC/RepPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_REPPREP_HPP_ +#define GC_REPPREP_HPP_ + #include "RepPrep.h" #include "ShareThread.h" #include "Processor/OnlineOptions.h" @@ -98,3 +101,5 @@ void RepPrep::buffer_inputs(int player) } } /* namespace GC */ + +#endif diff --git a/GC/Secret.h b/GC/Secret.h index 14f6638af..c4b6e8eb1 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -126,6 +126,9 @@ class Secret template static void convcbit2s(Processor& processor, const BaseInstruction& instruction) { T::convcbit2s(processor, instruction); } + template + static void andm(Processor& processor, const BaseInstruction& instruction) + { T::andm(processor, instruction); } Secret(); Secret(const Integer& x) { *this = x; } diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 9fc3f4918..9eed3b316 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -24,12 +24,15 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) : void SemiPrep::set_protocol(Beaver& protocol) { if (triple_generator) + { + assert(&triple_generator->get_player() == &protocol.P); return; + } (void) protocol; params.set_passive(); triple_generator = new SemiSecret::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), + BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, {}, &protocol.P); triple_generator->multi_threaded = false; @@ -61,12 +64,4 @@ void SemiPrep::buffer_bits() } } -NamedCommStats SemiPrep::comm_stats() -{ - if (triple_generator) - return triple_generator->comm_stats(); - else - return {}; -} - } /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 97214c28d..737cfb986 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -44,6 +44,8 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer get_triple_no_count(int n_bits) { + if (n_bits == -1) + n_bits = SemiSecret::default_length; return ShiftableTripleBuffer::get_triple_no_count(n_bits); } @@ -51,8 +53,6 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer static void convcbit2s(Processor& processor, const BaseInstruction& instruction) { processor.convcbit2s(instruction); } + static void andm(Processor& processor, const BaseInstruction& instruction) + { processor.andm(instruction); } static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 1a508828b..23c86cb28 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -47,7 +47,7 @@ void ShareSecret::invert(int n, const U& x) { U ones; ones.load_clear(64, -1); - static_cast(*this) = U(x ^ ones) & get_mask(n); + reinterpret_cast(*this) = U(x + ones) & get_mask(n); } template @@ -92,8 +92,12 @@ template void ShareSecret::store_clear_in_dynamic(Memory& mem, const vector& accesses) { + auto& thread = ShareThread::s(); + assert(thread.P); + assert(thread.MC); for (auto access : accesses) - mem[access.address] = access.value; + mem[access.address] = U::constant(access.value, thread.P->my_num(), + thread.MC->get_alphai()); } template @@ -330,7 +334,7 @@ void ShareSecret::random_bit() template U& GC::ShareSecret::operator=(const U& other) { - U& real_this = static_cast(*this); + U& real_this = reinterpret_cast(*this); real_this = other; return real_this; } diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 5f995e808..42c5e3bd6 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -58,9 +58,6 @@ class StandaloneShareThread : public ShareThread, public Thread void pre_run(); void post_run() { ShareThread::post_run(); } - - NamedCommStats comm_stats() - { return Thread::comm_stats() + this->DataF.comm_stats(); } }; template diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 14d496115..07085040b 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -63,6 +63,7 @@ void ShareThread::pre_run(Player& P, typename T::mac_key_type mac_key) protocol = new typename T::Protocol(*this->P); MC = this->new_mc(mac_key); DataF.set_protocol(*this->protocol); + this->protocol->init(DataF, *MC); } template @@ -85,7 +86,7 @@ void ShareThread::and_(Processor& processor, { auto& protocol = this->protocol; processor.check_args(args, 4); - protocol->init_mul(DataF, *this->MC); + protocol->init_mul(); T x_ext, y_ext; for (size_t i = 0; i < args.size(); i += 4) { diff --git a/GC/Thread.h b/GC/Thread.h index 659c070a0..6631ad723 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -55,8 +55,6 @@ class Thread void join_tape(); void finish(); - - virtual NamedCommStats comm_stats(); }; template diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 5487c41b2..d0b515cbf 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -96,13 +96,6 @@ void Thread::finish() pthread_join(thread, 0); } -template -NamedCommStats Thread::comm_stats() -{ - assert(P); - return P->comm_stats; -} - } /* namespace GC */ diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 060e9f118..c6c9dcaac 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -95,11 +95,11 @@ void ThreadMaster::run() post_run(); - NamedCommStats stats = P->comm_stats; + NamedCommStats stats = P->total_comm(); ExecutionStats exe_stats; for (auto thread : threads) { - stats += thread->P->comm_stats; + stats += thread->P->total_comm(); exe_stats += thread->processor.stats; delete thread; } diff --git a/GC/TinierSharePrep.h b/GC/TinierSharePrep.h index 34beaf6fb..4e316e38c 100644 --- a/GC/TinierSharePrep.h +++ b/GC/TinierSharePrep.h @@ -44,8 +44,6 @@ class TinierSharePrep : public PersonalPrep ~TinierSharePrep(); void set_protocol(typename T::Protocol& protocol); - - NamedCommStats comm_stats(); }; } diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index 57e759b9c..e136ec446 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -8,7 +8,7 @@ #include "TinierSharePrep.h" -#include "PersonalPrep.hpp" +#include "PersonalPrep.h" namespace GC { @@ -39,14 +39,17 @@ template void TinierSharePrep::set_protocol(typename T::Protocol& protocol) { if (triple_generator) + { + assert(&triple_generator->get_player() == &protocol.P); return; + } params.generateMACs = true; params.amplify = false; params.check = false; auto& thread = ShareThread::s(); triple_generator = new typename T::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, + BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &protocol.P); triple_generator->multi_threaded = false; @@ -84,17 +87,6 @@ void GC::TinierSharePrep::buffer_bits() BufferPrep::get_random_from_inputs(thread.P->num_players())); } -template -NamedCommStats TinierSharePrep::comm_stats() -{ - NamedCommStats res; - if (triple_generator) - res += triple_generator->comm_stats(); - if (real_triple_generator) - res += real_triple_generator->comm_stats(); - return res; -} - } #endif diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 2b8a11b79..897b3b482 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -16,7 +16,7 @@ void TinierSharePrep::init_real(Player& P) assert(real_triple_generator == 0); auto& thread = ShareThread::s(); real_triple_generator = new typename T::whole_type::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), P.N, -1, + BaseMachine::fresh_ot_setup(P), P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &P); real_triple_generator->multi_threaded = false; diff --git a/GC/VectorInput.h b/GC/VectorInput.h index c17cd93d4..44c9591b9 100644 --- a/GC/VectorInput.h +++ b/GC/VectorInput.h @@ -36,6 +36,8 @@ class VectorInput : public InputBase void add_mine(const typename T::open_type& input, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; for (int i = 0; i < n_bits; i++) part_input.add_mine(input.get_bit(i)); input_lengths.push_back(n_bits); @@ -43,6 +45,8 @@ class VectorInput : public InputBase void add_other(int player, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; for (int i = 0; i < n_bits; i++) part_input.add_other(player); } @@ -69,6 +73,8 @@ class VectorInput : public InputBase void finalize_other(int player, T& target, octetStream&, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; target.resize_regs(n_bits); for (int i = 0; i < n_bits; i++) part_input.finalize_other(player, target.get_reg(i), diff --git a/GC/VectorProtocol.h b/GC/VectorProtocol.h index 3f7e203c5..94ef19893 100644 --- a/GC/VectorProtocol.h +++ b/GC/VectorProtocol.h @@ -21,9 +21,10 @@ class VectorProtocol : public ProtocolBase VectorProtocol(Player& P); - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init(Preprocessing& prep, typename T::MAC_Check& MC); + + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); void finalize_mult(T& res, int n = -1); T finalize_mul(int n = -1); diff --git a/GC/VectorProtocol.hpp b/GC/VectorProtocol.hpp index cae461812..e72e0d148 100644 --- a/GC/VectorProtocol.hpp +++ b/GC/VectorProtocol.hpp @@ -18,26 +18,26 @@ VectorProtocol::VectorProtocol(Player& P) : } template -void VectorProtocol::init_mul(SubProcessor* proc) +void VectorProtocol::init(Preprocessing& prep, + typename T::MAC_Check& MC) { - assert(proc); - init_mul(proc->DataF, proc->MC); + part_protocol.init(prep.get_part(), MC.get_part_MC()); } template -void VectorProtocol::init_mul(Preprocessing& prep, - typename T::MAC_Check& MC) +void VectorProtocol::init_mul() { - part_protocol.init_mul(prep.get_part(), MC.get_part_MC()); + part_protocol.init_mul(); } template -typename T::clear VectorProtocol::prepare_mul(const T& x, +void VectorProtocol::prepare_mul(const T& x, const T& y, int n) { + if (n == -1) + n = T::default_length; for (int i = 0; i < n; i++) part_protocol.prepare_mul(x.get_reg(i), y.get_reg(i), 1); - return {}; } template @@ -57,6 +57,8 @@ T VectorProtocol::finalize_mul(int n) template void VectorProtocol::finalize_mult(T& res, int n) { + if (n == -1) + n = T::default_length; res.resize_regs(n); for (int i = 0; i < n; i++) res.get_reg(i) = part_protocol.finalize_mul(1); diff --git a/GC/instructions.h b/GC/instructions.h index fc278d441..66ae46d22 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -46,6 +46,7 @@ X(NOTCB, processor.notcb(INST)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ + X(ANDM, T::andm(PROC, instruction)) \ X(ADDCB, C0 = PC1 + PC2) \ X(ADDCBI, C0 = PC1 + int(IMM)) \ X(MULCBI, C0 = PC1 * int(IMM)) \ @@ -76,7 +77,6 @@ #define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \ X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \ X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \ - X(ANDM, processor.andm(instruction)) \ X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ diff --git a/License.txt b/License.txt index 3a9eb2ae0..ccaafe01e 100644 --- a/License.txt +++ b/License.txt @@ -1,5 +1,5 @@ CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) -Copyright (c) 2021, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. +Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material. Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. diff --git a/Machines/Atlas.hpp b/Machines/Atlas.hpp new file mode 100644 index 000000000..045b69b9e --- /dev/null +++ b/Machines/Atlas.hpp @@ -0,0 +1,16 @@ +/* + * Atlas.hpp + * + */ + +#ifndef MACHINES_ATLAS_HPP_ +#define MACHINES_ATLAS_HPP_ + +#include "Protocols/AtlasShare.h" +#include "Protocols/AtlasPrep.h" +#include "GC/AtlasSecret.h" + +#include "ShamirMachine.hpp" +#include "Protocols/Atlas.hpp" + +#endif /* MACHINES_ATLAS_HPP_ */ diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index d37c385c5..a480860f8 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -4,6 +4,7 @@ */ #include "Protocols/MalRepRingPrep.h" +#include "Protocols/ReplicatedPrep2k.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" diff --git a/Machines/Rep4.hpp b/Machines/Rep4.hpp new file mode 100644 index 000000000..83ad1cff5 --- /dev/null +++ b/Machines/Rep4.hpp @@ -0,0 +1,17 @@ +/* + * Rep4.hpp + * + */ + +#ifndef MACHINES_REP4_HPP_ +#define MACHINES_REP4_HPP_ + +#include "GC/Rep4Secret.h" +#include "Protocols/Rep4Share2k.h" +#include "Protocols/Rep4Prep.h" +#include "Protocols/Rep4.hpp" +#include "Protocols/Rep4MC.hpp" +#include "Protocols/Rep4Input.hpp" +#include "Protocols/Rep4Prep.hpp" + +#endif /* MACHINES_REP4_HPP_ */ diff --git a/Machines/SPDZ.hpp b/Machines/SPDZ.hpp index 02ad9b983..a221b087a 100644 --- a/Machines/SPDZ.hpp +++ b/Machines/SPDZ.hpp @@ -21,13 +21,15 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" +#include "GC/VectorProtocol.h" -#include "GC/ShareParty.hpp" +#include "GC/ShareParty.h" #include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/ShareSecret.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/CcdPrep.hpp" +#include "GC/ShareSecret.h" +#include "GC/TinierSharePrep.h" +#include "GC/CcdPrep.h" + +#include "GC/VectorProtocol.hpp" #include "Math/gfp.hpp" diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 672a29b4e..6cb02779d 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -23,9 +23,10 @@ #include "Protocols/MascotPrep.hpp" #include "Protocols/Spdz2kPrep.hpp" -#include "GC/ShareParty.hpp" -#include "GC/ShareSecret.hpp" +#include "GC/ShareParty.h" +#include "GC/ShareSecret.h" #include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.h" +#include "GC/CcdPrep.h" + +#include "GC/VectorProtocol.hpp" diff --git a/Machines/Semi.hpp b/Machines/Semi.hpp index 36c9d8c50..1a0931467 100644 --- a/Machines/Semi.hpp +++ b/Machines/Semi.hpp @@ -18,3 +18,4 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" diff --git a/Machines/Semi2k.hpp b/Machines/Semi2k.hpp new file mode 100644 index 000000000..56f86d9ba --- /dev/null +++ b/Machines/Semi2k.hpp @@ -0,0 +1,15 @@ +/* + * Semi2.hpp + * + */ + +#ifndef MACHINES_SEMI2K_HPP_ +#define MACHINES_SEMI2K_HPP_ + +#include "Protocols/Semi2kShare.h" +#include "Protocols/SemiPrep2k.h" + +#include "Semi.hpp" +#include "Protocols/RepRingOnlyEdabitPrep.hpp" + +#endif /* MACHINES_SEMI2K_HPP_ */ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 080332aea..7697c5124 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -27,6 +27,7 @@ #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" #include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/VectorProtocol.hpp" #include "GC/Secret.hpp" diff --git a/Machines/Tinier.cpp b/Machines/Tinier.cpp new file mode 100644 index 000000000..99ad1c5c1 --- /dev/null +++ b/Machines/Tinier.cpp @@ -0,0 +1,23 @@ +/* + * Tinier.cpp + * + */ + +#include "GC/TinyMC.h" +#include "GC/TinierSecret.h" +#include "GC/VectorInput.h" + +#include "GC/ShareParty.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/PersonalPrep.hpp" + +//template class GC::ShareParty>; +template class GC::CcdPrep>; +template class Preprocessing>; +template class GC::TinierSharePrep>; +template class GC::ShareSecret>; +template class TripleShuffleSacrifice>; diff --git a/Machines/atlas-party.cpp b/Machines/atlas-party.cpp index 6e754c7ff..2df033e60 100644 --- a/Machines/atlas-party.cpp +++ b/Machines/atlas-party.cpp @@ -3,12 +3,7 @@ * */ -#include "Protocols/AtlasShare.h" -#include "Protocols/AtlasPrep.h" -#include "GC/AtlasSecret.h" - -#include "ShamirMachine.hpp" -#include "Protocols/Atlas.hpp" +#include "Atlas.hpp" int main(int argc, const char** argv) { diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 8525b0671..f26f5f324 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -10,11 +10,13 @@ #include "Processor/RingOptions.h" #include "Processor/Machine.hpp" +#include "Processor/OnlineOptions.hpp" #include "Math/Z2k.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/ShuffleSacrifice.hpp" #include "Protocols/ReplicatedPrep.hpp" #include "Protocols/FakeShare.hpp" +#include "Protocols/MalRepRingPrep.hpp" int main(int argc, const char** argv) { @@ -22,7 +24,7 @@ int main(int argc, const char** argv) Names N; ez::ezOptionParser opt; RingOptions ring_opts(opt, argc, argv); - online_opts = {opt, argc, argv}; + online_opts = {opt, argc, argv, FakeShare>()}; opt.parse(argc, argv); opt.syntax = string(argv[0]) + " "; @@ -44,9 +46,7 @@ int main(int argc, const char** argv) #ifdef ROUND_NEAREST_IN_EMULATION cerr << "Using nearest rounding instead of probabilistic truncation" << endl; #else -#ifdef RISKY_TRUNCATION_IN_EMULATION - cerr << "Using risky truncation" << endl; -#endif + online_opts.set_trunc_error(opt); #endif int R = ring_opts.ring_size_from_opts_or_schedule(progname); diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp index 471862dab..934c15dcd 100644 --- a/Machines/hemi-party.cpp +++ b/Machines/hemi-party.cpp @@ -24,6 +24,7 @@ #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Hemi.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" #include "Math/gfp.hpp" diff --git a/Machines/no-party.cpp b/Machines/no-party.cpp index 2120322f3..ce542de18 100644 --- a/Machines/no-party.cpp +++ b/Machines/no-party.cpp @@ -8,6 +8,7 @@ #include "Processor/OnlineMachine.hpp" #include "Processor/Machine.hpp" #include "Protocols/Replicated.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "Math/gfp.hpp" #include "Math/Z2k.hpp" diff --git a/Machines/soho-party.cpp b/Machines/soho-party.cpp index 6f7c70a3a..7ecc450da 100644 --- a/Machines/soho-party.cpp +++ b/Machines/soho-party.cpp @@ -22,6 +22,7 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" #include "Math/gfp.hpp" diff --git a/Makefile b/Makefile index 9d634c0e9..e40528b8c 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,8 @@ MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) VM = $(MINI_OT) $(SHAREDLIB) COMMON = $(SHAREDLIB) +TINIER = Machines/Tinier.o $(OT) +SPDZ = Machines/SPDZ.o $(TINIER) LIB = libSPDZ.a @@ -117,7 +119,7 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) -$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC) +$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(TINIER) $(GC) $(AR) -csr $@ $^ CFLAGS += -fPIC @@ -203,16 +205,16 @@ ps-rep-bin-party.x: GC/PostSacriBin.o semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o tiny-party.x: $(OT) tinier-party.x: $(OT) -spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) +spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) -cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) -chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) -lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o -highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o +cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) +chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) +lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o +highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS) @@ -220,10 +222,10 @@ static/cowgear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS) static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o -mascot-party.x: Machines/SPDZ.o $(OT) -static/mascot-party.x: Machines/SPDZ.o -Player-Online.x: Machines/SPDZ.o $(OT) -mama-party.x: $(OT) +mascot-party.x: $(SPDZ) +static/mascot-party.x: $(SPDZ) +Player-Online.x: $(SPDZ) +mama-party.x: $(TINIER) ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o @@ -236,8 +238,10 @@ emulate.x: GC/FakeSecret.o semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -mascot-offline.x: $(VM) $(OT) -cowgear-offline.x: $(OT) $(FHEOFFLINE) +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o +mascot-offline.x: $(VM) $(TINIER) +cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) static/mal-rep-bmr-party.x: $(BMR) static/shamir-bmr-party.x: $(BMR) diff --git a/Math/BitVec.h b/Math/BitVec.h index f9e874d14..f0d60a1b9 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -26,6 +26,7 @@ class BitVec_ : public IntBase static const false_type invertible; static const true_type characteristic_two; + static const true_type binary; static char type_char() { return 'B'; } static string type_short() { return "B"; } @@ -64,8 +65,21 @@ class BitVec_ : public IntBase void pack(octetStream& os) const { os.store_int(this->a); } void unpack(octetStream& os) { this->a = os.get_int(); } - void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } - void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); } + void pack(octetStream& os, int n) const + { + if (n == -1) + pack(os); + else + os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); + } + + void unpack(octetStream& os, int n) + { + if (n == -1) + unpack(os); + else + this->a = os.get_int(DIV_CEIL(n, 8)); + } static BitVec_ unpack_new(octetStream& os, int n = n_bits) { @@ -81,5 +95,7 @@ template const false_type BitVec_::invertible; template const true_type BitVec_::characteristic_two; +template +const true_type BitVec_::binary; #endif /* MATH_BITVEC_H_ */ diff --git a/Math/Setup.hpp b/Math/Setup.hpp index 6545d67ec..91cafaea5 100644 --- a/Math/Setup.hpp +++ b/Math/Setup.hpp @@ -36,8 +36,9 @@ void read_setup(const string& dir_prefix, int lgp = -1) { if (lgp > 0) { - cerr << "No modulus found in " << filename << ", generating " << lgp - << "-bit prime" << endl; + if (OnlineOptions::singleton.verbose) + cerr << "No modulus found in " << filename << ", generating " + << lgp << "-bit prime" << endl; T::init_default(lgp); } else diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index d15af24c8..07807cb23 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -20,6 +20,7 @@ class ValueInterface static const false_type characteristic_two; static const false_type prime_field; static const false_type invertible; + static const false_type binary; template static void init(bool mont = true) { (void) mont; } diff --git a/Math/Z2k.h b/Math/Z2k.h index 3e6530442..ad32cbf16 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -47,6 +47,7 @@ class Z2 : public ValueInterface static int size_in_limbs() { return N_WORDS; } static int size_in_bits() { return size() * 8; } static int length() { return size_in_bits(); } + static int n_bits() { return N_BITS; } static int t() { return 0; } static char type_char() { return 'R'; } @@ -100,6 +101,8 @@ class Z2 : public ValueInterface int bit_length() const; + Z2 mask(int) const { return *this; } + Z2 operator+(const Z2& other) const; Z2 operator-(const Z2& other) const; diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 63c279a26..17fcdf24c 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -86,6 +86,42 @@ void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y,int t { inline_mpn_copyi(z,ans+t,t); } } +void Zp_Data::Mont_Mult_switch(mp_limb_t* z, const mp_limb_t* x, + const mp_limb_t* y) const +{ + switch (t) + { +#ifdef __BMI2__ +#define CASE(N) \ + case N: \ + Mont_Mult_(z, x, y); \ + break; + CASE(1) + CASE(2) +#if MAX_MOD_SZ >= 4 + CASE(3) + CASE(4) +#endif +#if MAX_MOD_SZ >= 5 + CASE(5) +#endif +#if MAX_MOD_SZ >= 6 + CASE(6) +#endif +#if MAX_MOD_SZ >= 10 + CASE(7) + CASE(8) + CASE(9) + CASE(10) +#endif +#undef CASE +#endif + default: + Mont_Mult_variable(z, x, y); + break; + } +} + ostream& operator<<(ostream& s,const Zp_Data& ZpD) diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 96deb7951..f30e71037 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -40,6 +40,7 @@ class Zp_Data template void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; + void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const; void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const { Mont_Mult(z, x, y, t); } @@ -242,37 +243,11 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* { if (not cpu_has_bmi2()) return Mont_Mult_variable(z, x, y); - switch (t) - { #ifdef __BMI2__ -#define CASE(N) \ - case N: \ - Mont_Mult_(z, x, y); \ - break; - CASE(1) - CASE(2) -#if MAX_MOD_SZ >= 4 - CASE(3) - CASE(4) -#endif -#if MAX_MOD_SZ >= 5 - CASE(5) -#endif -#if MAX_MOD_SZ >= 6 - CASE(6) -#endif -#if MAX_MOD_SZ >= 10 - CASE(7) - CASE(8) - CASE(9) - CASE(10) -#endif -#undef CASE + return Mont_Mult_switch(z, x, y); +#else + return Mont_Mult_variable(z, x, y); #endif - default: - Mont_Mult_variable(z, x, y); - break; - } } inline void Zp_Data::Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x, diff --git a/Math/gfp.h b/Math/gfp.h index 7b257b5fa..bde43025e 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -11,7 +11,6 @@ using namespace std; #include "Math/Bit.h" #include "Math/Setup.h" #include "Tools/random.h" -#include "GC/NoShare.h" #include "Processor/OnlineOptions.h" #include "Math/modp.hpp" @@ -101,6 +100,7 @@ class gfp_ : public ValueInterface static int size() { return t() * sizeof(mp_limb_t); } static int size_in_bits() { return 8 * size(); } static int length() { return ZpD.pr_bit_length; } + static int n_bits() { return length() - 1; } static void reqbl(int n); diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index c2e1403b9..9d8da6514 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -5,6 +5,7 @@ #include "CryptoPlayer.h" #include "Math/Setup.h" +#include "Tools/Bundle.h" void check_ssl_file(string filename) { @@ -124,12 +125,14 @@ CryptoPlayer::~CryptoPlayer() void CryptoPlayer::send_to_no_stats(int other, const octetStream& o) const { + assert(other != my_num()); senders[other]->request(o); senders[other]->wait(o); } void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const { + assert(other != my_num()); receivers[other]->request(o); receivers[other]->wait(o); } @@ -137,6 +140,7 @@ void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send, octetStream& to_receive) const { + assert(other != my_num()); if (&to_send == &to_receive) { MultiPlayer::exchange_no_stats(other, to_send, to_receive); @@ -153,6 +157,7 @@ void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send, void CryptoPlayer::pass_around_no_stats(const octetStream& to_send, octetStream& to_receive, int offset) const { + assert(get_player(offset) != my_num()); if (&to_send == &to_receive) { MultiPlayer::pass_around_no_stats(to_send, to_receive, offset); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 61c8fd65c..cd92df541 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -14,12 +14,14 @@ using namespace std; -void Names::init(int player,int pnb,int my_port,const char* servername) +void Names::init(int player, int pnb, int my_port, const char* servername, + bool setup_socket) { player_no=player; portnum_base=pnb; setup_names(servername, my_port); - setup_server(); + if (setup_socket) + setup_server(); } Names::Names(int player, int nplayers, const string& servername, int pnb, @@ -124,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port) my_port = default_port(player_no); int socket_num; - int pn = portnum_base - 1; + int pn = portnum_base; set_up_client_socket(socket_num, servername, pn); octetStream("P" + to_string(player_no)).Send(socket_num); #ifdef DEBUG_NETWORKING @@ -132,15 +134,11 @@ void Names::setup_names(const char *servername, int my_port) #endif // Send my name - octet my_name[512]; - memset(my_name,0,512*sizeof(octet)); sockaddr_in address; socklen_t size = sizeof address; getsockname(socket_num, (sockaddr*)&address, &size); - char* name = inet_ntoa(address.sin_addr); - // max length of IP address with ending 0 - strncpy((char*)my_name, name, 16); - send(socket_num,my_name,512); + char* my_name = inet_ntoa(address.sin_addr); + octetStream(my_name).Send(socket_num); send(socket_num,(octet*)&my_port,4); #ifdef DEBUG_NETWORKING fprintf(stderr, "My Name = %s\n",my_name); @@ -158,9 +156,10 @@ void Names::setup_names(const char *servername, int my_port) names.resize(nplayers); ports.resize(nplayers); for (i=0; iinit(); } +void Names::set_server(ServerSocket* socket) +{ + assert(not server); + server = socket; +} + Names::Names(const Names& other) { @@ -201,6 +206,7 @@ Player::Player(const Names& Nms) : { nplayers=Nms.nplayers; player_no=Nms.player_no; + thread_stats.resize(nplayers); } @@ -243,6 +249,10 @@ MultiPlayer::~MultiPlayer() Player::~Player() { +#ifdef VERBOSE + for (auto& x : thread_stats) + x.print(); +#endif } PlayerBase::~PlayerBase() @@ -685,7 +695,7 @@ void VirtualTwoPartyPlayer::send(octetStream& o) const { TimeScope ts(comm_stats["Sending one-to-one"].add(o)); P.send_to_no_stats(other_player, o); - sent += o.get_length(); + comm_stats.sent += o.get_length(); } void RealTwoPartyPlayer::receive(octetStream& o) const @@ -729,12 +739,13 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const void VirtualTwoPartyPlayer::send_receive_player(vector& o) const { TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0])); - sent += o[0].get_length(); + comm_stats.sent += o[0].get_length(); P.exchange_no_stats(other_player, o[0], o[1]); } VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) : - TwoPartyPlayer(P.my_num()), P(P), other_player(other_player) + TwoPartyPlayer(P.my_num()), P(P), other_player(other_player), comm_stats( + P.thread_stats.at(other_player)) { } @@ -814,5 +825,13 @@ void NamedCommStats::print(bool newline) cerr << endl; } +NamedCommStats Player::total_comm() const +{ + auto res = comm_stats; + for (auto& x : thread_stats) + res += x; + return res; +} + template class MultiPlayer; template class MultiPlayer ; diff --git a/Networking/Player.h b/Networking/Player.h index 033aa3bd1..9c90dbd1f 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -35,6 +35,7 @@ class Names friend class Player; friend class PlainPlayer; friend class RealTwoPartyPlayer; + friend class Server; vector names; vector ports; @@ -51,6 +52,8 @@ class Names void setup_server(); + void set_server(ServerSocket* socket); + public: static const int DEFAULT_PORT = -1; @@ -62,8 +65,10 @@ class Names * @param my_port my port number (`DEFAULT_PORT` for default, * which is base port number plus player number) * @param servername location of server + * @param setup_socket whether to start listening */ - void init(int player,int pnb,int my_port,const char* servername); + void init(int player, int pnb, int my_port, const char* servername, + bool setup_socket = true); Names(int player,int pnb,int my_port,const char* servername) : Names() { init(player,pnb,my_port,servername); } @@ -172,11 +177,12 @@ class PlayerBase protected: int player_no; -public: size_t& sent; - mutable Timer timer; mutable NamedCommStats comm_stats; +public: + mutable Timer timer; + PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {} virtual ~PlayerBase(); @@ -205,6 +211,8 @@ class Player : public PlayerBase public: const Names& N; + mutable vector thread_stats; + Player(const Names& Nms); virtual ~Player(); @@ -358,6 +366,8 @@ class Player : public PlayerBase virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; } virtual void wait_receive(int i, octetStream& o) const { receive_player(i, o); } + + NamedCommStats total_comm() const; }; /** @@ -500,6 +510,7 @@ class VirtualTwoPartyPlayer : public TwoPartyPlayer { Player& P; int other_player; + NamedCommStats& comm_stats; public: VirtualTwoPartyPlayer(Player& P, int other_player); diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp index e93f47c44..7e8c93fe9 100644 --- a/Networking/Receiver.cpp +++ b/Networking/Receiver.cpp @@ -51,9 +51,17 @@ void Receiver::run() while (in.pop(os)) { os->reset_write_head(); +#ifdef VERBOSE_SSL timer.start(); + RunningTimer mytimer; +#endif os->Receive(socket); +#ifdef VERBOSE_SSL + cout << "receiving " << os->get_length() * 1e-6 << " MB on " << socket + << " took " << mytimer.elapsed() << ", total " + << timer.elapsed() << endl; timer.stop(); +#endif out.push(os); } } diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp index 51d5f4711..4e4b98810 100644 --- a/Networking/Sender.cpp +++ b/Networking/Sender.cpp @@ -47,9 +47,17 @@ void Sender::run() const octetStream* os = 0; while (in.pop(os)) { -// timer.start(); +#ifdef VERBOSE_SSL + timer.start(); + RunningTimer mytimer; +#endif os->Send(socket); -// timer.stop(); +#ifdef VERBOSE_SSL + cout << "sending " << os->get_length() * 1e-6 << " MB on " << socket + << " took " << mytimer.elapsed() << ", total " + << timer.elapsed() << endl; + timer.stop(); +#endif out.push(os); } } diff --git a/Networking/Server.cpp b/Networking/Server.cpp index d9a056dd2..facda0a26 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -28,9 +28,7 @@ void Server::get_ip(int num) inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr); } - names[num]=new octet[512]; - memset(names[num], 0, 512); - strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN); + names[num] = ipstr; #ifdef DEBUG_NETWORKING cerr << "Client IP address: " << names[num] << endl; @@ -45,11 +43,11 @@ void Server::get_name(int num) #endif // Receive name sent by client (legacy) - not used here - octet my_name[512]; - receive(socket_num[num],my_name,512); + octetStream os; + os.Receive(socket_num[num]); receive(socket_num[num],(octet*)&ports[num],4); #ifdef DEBUG_NETWORKING - cerr << "Player " << num << " sent (IP for info only) " << my_name << ":" + cerr << "Player " << num << " sent (IP for info only) " << os.str() << ":" << ports[num] << endl; #endif @@ -66,7 +64,7 @@ void Server::send_names(int num) send(socket_num[num],nmachines,4); for (int i=0; i= 0); assert(my_num < nplayers); @@ -172,12 +175,19 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, { pthread_create(&thread, 0, Server::start_in_thread, server = new Server(nplayers, portnum)); - } - N.init(my_num, portnum, my_port, hostname.c_str()); - if (my_num == 0) - { + N.init(my_num, portnum, my_port, hostname.c_str(), false); pthread_join(thread, 0); + N.set_server(server->get_socket()); delete server; } + else + N.init(my_num, portnum, my_port, hostname.c_str()); return 0; } + +ServerSocket* Server::get_socket() +{ + auto res = server_socket; + server_socket = 0; + return res; +} diff --git a/Networking/Server.h b/Networking/Server.h index a5e833add..ad6d5fd5d 100644 --- a/Networking/Server.h +++ b/Networking/Server.h @@ -14,10 +14,11 @@ using namespace std; class Server { vector socket_num; - vector names; + vector names; vector ports; int nmachines; int PortnumBase; + ServerSocket* server_socket; void get_ip(int num); void get_name(int num); @@ -31,7 +32,11 @@ class Server Server(int argc, char** argv); Server(int nmachines, int PortnumBase); + ~Server(); + void start(); + + ServerSocket* get_socket(); }; #endif /* NETWORKING_SERVER_H_ */ diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index 8989a0a10..79cb35222 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -7,6 +7,7 @@ #define CRYPTO_SSL_SOCKETS_H_ #include "Tools/int.h" +#include "Tools/time-func.h" #include "sockets.h" #include "Math/Setup.h" @@ -46,6 +47,10 @@ class ssl_socket : public boost::asio::ssl::stream string me, bool client) : parent(io_service, ctx) { +#ifdef DEBUG_NETWORKING + cerr << me << " setting up SSL to " << other << " as " << + (client ? "client" : "server") << endl; +#endif lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket); set_verify_mode(boost::asio::ssl::verify_peer); set_verify_callback(boost::asio::ssl::rfc2818_verification(other)); @@ -82,8 +87,16 @@ template<> inline void send(ssl_socket* socket, octet* data, size_t length) { size_t sent = 0; +#ifdef VERBOSE_SSL + RunningTimer timer; +#endif while (sent < length) + { sent += send_non_blocking(socket, data + sent, length - sent); +#ifdef VERBOSE_SSL + cout << "sent " << sent * 1e-6 << " MB at " << timer.elapsed() << endl; +#endif + } } template<> diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 8847728e9..988565854 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -1,6 +1,7 @@ #include "OT/BaseOT.h" #include "Tools/random.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" #include #include @@ -78,6 +79,23 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE rol void BaseOT::exec_base(bool new_receiver_inputs) { + Bundle bundle(*P); +#ifdef NO_AVX_OT + bundle.mine = string("OT without AVX"); +#else + bundle.mine = string("OT with AVX"); +#endif + try + { + bundle.compare(*P); + } + catch (mismatch_among_parties&) + { + cerr << "Parties compiled with different base OT algorithms" << endl; + cerr << "Set \"AVX_OT\" to the same value on all parties" << endl; + exit(1); + } + #ifdef NO_AVX_OT #ifdef USE_RISTRETTO typedef CurveElement Element; diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index d5981e713..8a84ca0a3 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -116,7 +116,7 @@ class OTTripleGenerator : public GeneratorThread mac_key_type get_mac_key() const { return mac_key; } - NamedCommStats comm_stats(); + Player& get_player() { return globalPlayer; } }; template @@ -209,15 +209,4 @@ class Spdz2kTripleGenerator : public NPartyTripleGenerator void generateTriples(); }; -template -NamedCommStats OTTripleGenerator::comm_stats() -{ - NamedCommStats res; - if (parentPlayer != &globalPlayer) - res = globalPlayer.comm_stats; - for (auto& player : players) - res += player->comm_stats; - return res; -} - #endif diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index bc36a8606..019fc6f28 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -110,22 +110,31 @@ void BaseMachine::time() void BaseMachine::start(int n) { cout << "Starting timer " << n << " at " << timer[n].elapsed() + << " (" << timer[n].mb_sent() << " MB)" << " after " << timer[n].idle() << endl; - timer[n].start(); + timer[n].start(total_comm()); } void BaseMachine::stop(int n) { - timer[n].stop(); - cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl; + timer[n].stop(total_comm()); + cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " (" + << timer[n].mb_sent() << " MB)" << endl; } void BaseMachine::print_timers() { + cerr << "The following timing is "; + if (OnlineOptions::singleton.live_prep) + cerr << "in"; + else + cerr << "ex"; + cerr << "clusive preprocessing." << endl; cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; timer.erase(0); - for (map::iterator it = timer.begin(); it != timer.end(); it++) - cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; + for (auto it = timer.begin(); it != timer.end(); it++) + cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds (" + << it->second.mb_sent() << " MB)" << endl; } string BaseMachine::memory_filename(const string& type_short, int my_number) @@ -170,3 +179,18 @@ bigint BaseMachine::prime_from_schedule(string progname) else return 0; } + +NamedCommStats BaseMachine::total_comm() +{ + NamedCommStats res; + for (auto& queue : queues) + res += queue->get_comm_stats(); + return res; +} + +void BaseMachine::set_thread_comm(const NamedCommStats& stats) +{ + auto queue = queues.at(BaseMachine::thread_num); + assert(queue); + queue->set_comm_stats(stats); +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 0e08549e3..035a0cfef 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -7,6 +7,7 @@ #define PROCESSOR_BASEMACHINE_H_ #include "Tools/time-func.h" +#include "Tools/TimerWithComm.h" #include "OT/OTTripleSetup.h" #include "ThreadJob.h" #include "ThreadQueues.h" @@ -22,7 +23,7 @@ class BaseMachine protected: static BaseMachine* singleton; - std::map timer; + std::map timer; string compiler; string domain; @@ -66,12 +67,18 @@ class BaseMachine virtual void reqbl(int) {} - OTTripleSetup fresh_ot_setup(); + static OTTripleSetup fresh_ot_setup(Player& P); + + NamedCommStats total_comm(); + void set_thread_comm(const NamedCommStats& stats); }; -inline OTTripleSetup BaseMachine::fresh_ot_setup() +inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) { - return ot_setups.at(thread_num).get_fresh(); + if (singleton and size_t(thread_num) < s().ot_setups.size()) + return s().ot_setups.at(thread_num).get_fresh(); + else + return OTTripleSetup(P, true); } #endif /* PROCESSOR_BASEMACHINE_H_ */ diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index be1fb8fdb..9878f4a6b 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -38,7 +38,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, int size_in_bytes = T::size() * buffer.size(); int n_read = 0; - char * read_buffer = new char[size_in_bytes]; + char read_buffer[size_in_bytes]; inf.seekg(start_posn * T::size()); do { diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 8f44ed253..8d05747e0 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -89,6 +89,7 @@ template class Processor; template class Data_Files; template class Machine; template class SubProcessor; +template class NoFilePrep; /** * Abstract base class for preprocessing @@ -125,6 +126,7 @@ class Preprocessing : public PrepBase template static Preprocessing* get_new(Machine& machine, DataPositions& usage, SubProcessor* proc); + template static Preprocessing* get_new(bool live_prep, const Names& N, DataPositions& usage); static Preprocessing* get_live_prep(SubProcessor* proc, @@ -133,22 +135,21 @@ class Preprocessing : public PrepBase Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {} virtual ~Preprocessing() {} - virtual void set_protocol(typename T::Protocol& protocol) = 0; + virtual void set_protocol(typename T::Protocol&) {}; virtual void set_proc(SubProcessor* proc) { (void) proc; } virtual void seekg(DataPositions& pos) { (void) pos; } virtual void prune() {} virtual void purge() {} - virtual size_t data_sent() { return comm_stats().sent; } - virtual NamedCommStats comm_stats() { return {}; } - - virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0; - virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0; - virtual void get_one_no_count(Dtype dtype, T& a) = 0; - virtual void get_input_no_count(T& a, typename T::open_type& x, int i) = 0; - virtual void get_no_count(vector& S, DataTag tag, const vector& regs, - int vector_size) = 0; + virtual void get_three_no_count(Dtype, T&, T&, T&) + { throw not_implemented(); } + virtual void get_two_no_count(Dtype, T&, T&) { throw not_implemented(); } + virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); } + virtual void get_input_no_count(T&, typename T::open_type&, int) + { throw not_implemented() ; } + virtual void get_no_count(vector&, DataTag, const vector&, int) + { throw not_implemented(); } void get(Dtype dtype, T* a); void get_three(Dtype dtype, T& a, T& b, T& c); @@ -191,6 +192,9 @@ class Sub_Data_Files : public Preprocessing { template friend class Sub_Data_Files; + typedef typename conditional, NoFilePrep>::type part_type; + static int tuple_length(int dtype); BufferOwner buffers[N_DTYPE]; @@ -205,7 +209,7 @@ class Sub_Data_Files : public Preprocessing const string prep_data_dir; int thread_num; - Sub_Data_Files* part; + part_type* part; void buffer_edabits_with_queues(bool strict, int n_bits) { buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); } @@ -274,7 +278,7 @@ class Sub_Data_Files : public Preprocessing void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size); void get_dabit_no_count(T& a, typename T::bit_type& b); - Preprocessing& get_part(); + part_type& get_part(); }; template @@ -307,8 +311,6 @@ class Data_Files } void reset_usage() { usage.reset(); skipped.reset(); } - - NamedCommStats comm_stats(); }; template inline @@ -418,6 +420,7 @@ T Preprocessing::get_bit() template T Preprocessing::get_random() { + assert(not usage.inputs.empty()); return get_random_from_inputs(usage.inputs.size()); } @@ -429,10 +432,4 @@ inline void Data_Files::purge() DataFb.purge(); } -template -NamedCommStats Data_Files::comm_stats() -{ - return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats(); -} - #endif diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 3635dc0ac..359ff6207 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -3,6 +3,7 @@ #include "Processor/Data_Files.h" #include "Processor/Processor.h" +#include "Processor/NoFilePrep.h" #include "Protocols/dabit.h" #include "Math/Setup.h" #include "GC/BitPrepFiles.h" @@ -30,6 +31,7 @@ Preprocessing* Preprocessing::get_new( } template +template Preprocessing* Preprocessing::get_new( bool live_prep, const Names& N, DataPositions& usage) @@ -156,17 +158,7 @@ Data_Files::Data_Files(const Names& N) : template Data_Files::~Data_Files() { -#ifdef VERBOSE - if (DataFp.data_sent()) - cerr << "Sent for " << sint::type_string() << " preprocessing threads: " << - DataFp.data_sent() * 1e-6 << " MB" << endl; -#endif delete &DataFp; -#ifdef VERBOSE - if (DataF2.data_sent()) - cerr << "Sent for " << sgf2n::type_string() << " preprocessing threads: " << - DataF2.data_sent() * 1e-6 << " MB" << endl; -#endif delete &DataF2; delete &DataFb; } @@ -264,6 +256,8 @@ void Sub_Data_Files::purge() for (auto it : extended) it.second.purge(); dabit_buffer.purge(); + if (part != 0) + part->purge(); } template @@ -329,10 +323,10 @@ void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, } template -Preprocessing& Sub_Data_Files::get_part() +typename Sub_Data_Files::part_type& Sub_Data_Files::get_part() { if (part == 0) - part = new Sub_Data_Files(my_num, num_players, + part = new part_type(my_num, num_players, get_prep_sub_dir(num_players), this->usage, thread_num); return *part; diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 95bcd029a..b3ed5bc54 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -87,10 +87,10 @@ class DummyProtocol : public ProtocolBase { } - void init_mul(SubProcessor* = 0) + void init_mul() { } - typename T::clear prepare_mul(const T&, const T&, int = 0) + void prepare_mul(const T&, const T&, int = 0) { throw not_implemented(); } diff --git a/Processor/FieldMachine.h b/Processor/FieldMachine.h index c544fb96a..859c64a1f 100644 --- a/Processor/FieldMachine.h +++ b/Processor/FieldMachine.h @@ -9,6 +9,9 @@ #include "RingMachine.h" #include "HonestMajorityMachine.h" #include "Tools/ezOptionParser.h" +#include "Math/gfp.h" + +#include "OnlineOptions.hpp" template class U, class V = HonestMajorityMachine> class HonestMajorityFieldMachine @@ -36,7 +39,7 @@ class DishonestMajorityFieldMachine ez::ezOptionParser& opt, bool live_prep_default = true) { OnlineOptions& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, 1000, live_prep_default, true}; + online_opts = {opt, argc, argv, T(), live_prep_default}; FieldMachine(argc, argv, opt, online_opts); } diff --git a/Processor/FieldMachine.hpp b/Processor/FieldMachine.hpp index f93517d98..89ec66e1c 100644 --- a/Processor/FieldMachine.hpp +++ b/Processor/FieldMachine.hpp @@ -10,6 +10,7 @@ #include "HonestMajorityMachine.h" #include "Math/gfp.h" #include "OnlineMachine.hpp" +#include "OnlineOptions.hpp" template class T, class V> @@ -24,7 +25,7 @@ template class T, class V> HonestMajorityFieldMachine::HonestMajorityFieldMachine(int argc, const char **argv, ez::ezOptionParser& opt, int nplayers) { - OnlineOptions online_opts(opt, argc, argv, 0, true, true); + OnlineOptions online_opts(opt, argc, argv, T()); FieldMachine(argc, argv, opt, online_opts, nplayers); } diff --git a/Processor/HonestMajorityMachine.cpp b/Processor/HonestMajorityMachine.cpp index 3a756bc8b..295ef5fa0 100644 --- a/Processor/HonestMajorityMachine.cpp +++ b/Processor/HonestMajorityMachine.cpp @@ -18,7 +18,6 @@ HonestMajorityMachine::HonestMajorityMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) : OnlineMachine(argc, argv, opt, online_opts, nplayers) { - OnlineOptions::singleton = online_opts; opt.add( "", // Default. 0, // Required? @@ -29,6 +28,7 @@ HonestMajorityMachine::HonestMajorityMachine(int argc, const char** argv, "--unencrypted" // Flag token. ); online_opts.finalize(opt, argc, argv); + OnlineOptions::singleton = online_opts; use_encryption = not opt.get("-u")->isSet; diff --git a/Processor/Input.h b/Processor/Input.h index 9816c3578..98c6c83b0 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -14,6 +14,8 @@ using namespace std; #include "Tools/PointerVector.h" class ArithmeticProcessor; +template class SubProcessor; +template class Preprocessing; /** * Abstract base for input protocols @@ -25,6 +27,7 @@ class InputBase protected: Player* P; + int my_num; Buffer buffer; Timer timer; @@ -58,7 +61,7 @@ class InputBase /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; /// Schedule input from all players - void add_from_all(const clear& input); + void add_from_all(const clear& input, int n_bits = -1); /// Send my inputs virtual void send_mine() = 0; diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 9272535bc..b9f7a77ab 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -19,6 +19,7 @@ template InputBase::InputBase(ArithmeticProcessor* proc) : P(0), values_input(0) { + my_num = -1; if (proc) buffer.setup(&proc->private_input, -1, proc->private_input_filename); } @@ -83,6 +84,7 @@ template void InputBase::reset_all(Player& P) { this->P = &P; + my_num = P.my_num(); os.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) reset(i); @@ -111,13 +113,13 @@ void Input::add_other(int player, int) } template -void InputBase::add_from_all(const clear& input) +void InputBase::add_from_all(const clear& input, int n_bits) { for (int i = 0; i < P->num_players(); i++) if (i == P->my_num()) - add_mine(input); + add_mine(input, n_bits); else - add_other(i); + add_other(i, n_bits); } template @@ -202,7 +204,7 @@ void Input::finalize_other(int player, T& target, template T InputBase::finalize(int player, int n_bits) { - if (player == P->my_num()) + if (player == my_num) return finalize_mine(); else { diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e516fdf37..e45a85045 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -1091,9 +1091,11 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.time(); break; case START: + Proc.machine.set_thread_comm(Proc.P.total_comm()); Proc.machine.start(n); break; case STOP: + Proc.machine.set_thread_comm(Proc.P.total_comm()); Proc.machine.stop(n); break; case RUN_TAPE: diff --git a/Processor/Machine.h b/Processor/Machine.h index 3f23dc9f9..331a9a22c 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -69,7 +69,6 @@ class Machine : public BaseMachine OnlineOptions opts; - NamedCommStats comm_stats; ExecutionStats stats; Machine(int my_number, Names& playerNames, const string& progname, diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 804dc51aa..d7d1a3ec3 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -142,6 +142,8 @@ template Machine::~Machine() { delete P; + for (auto& queue : queues) + delete queue; } template @@ -324,7 +326,7 @@ void Machine::run() { Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); proc_timer.start(); - timer[0].start(); + timer[0].start({}); // run main tape run_tape(0, 0, 0, N.num_players()); @@ -352,7 +354,6 @@ void Machine::run() queues[i]->schedule({}); pos.increase(queues[i]->result().pos); pthread_join(threads[i],NULL); - delete queues[i]; } finish_timer.stop(); @@ -372,6 +373,8 @@ void Machine::run() cerr << "Finish timer: " << finish_timer.elapsed() << endl; #endif + NamedCommStats comm_stats = total_comm(); + if (opts.verbose) { cerr << "Communication details " @@ -457,9 +460,12 @@ void Machine::run() } #ifndef INSECURE - Data_Files df(*this); - df.seekg(pos); - df.prune(); + if (not opts.file_prep_per_thread) + { + Data_Files df(*this); + df.seekg(pos); + df.prune(); + } #endif sint::LivePrep::teardown(); diff --git a/Processor/Memory.h b/Processor/Memory.h index 2c4a3d2e3..9ec02d2b8 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -43,8 +43,11 @@ class Memory template static void check_index(const vector& M, size_t i) { + (void) M, (void) i; +#ifdef NO_CHECK_INDEX if (i >= M.size()) throw overflow("memory", i, M.size()); +#endif } const typename T::clear& read_C(size_t i) const diff --git a/Processor/NoFilePrep.h b/Processor/NoFilePrep.h new file mode 100644 index 000000000..fbb44912e --- /dev/null +++ b/Processor/NoFilePrep.h @@ -0,0 +1,22 @@ +/* + * NoFilePrep.h + * + */ + +#ifndef PROCESSOR_NOFILEPREP_H_ +#define PROCESSOR_NOFILEPREP_H_ + +#include "Data_Files.h" + +template +class NoFilePrep : public Preprocessing +{ +public: + NoFilePrep(int, int, const string&, DataPositions& usage, int = -1) : + Preprocessing(usage) + { + throw runtime_error("don't call this"); + } +}; + +#endif /* PROCESSOR_NOFILEPREP_H_ */ diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index cffaded40..dcfafe553 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -71,7 +71,7 @@ void OfflineMachine::generate() auto my_usage = domain_usage[i]; Dtype dtype = Dtype(i); string filename = Sub_Data_Files::get_filename(playerNames, dtype, - T::clear::field_type() == DATA_GF2 ? 0 : -1); + 0); if (my_usage > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -106,7 +106,7 @@ void OfflineMachine::generate() for (int i = 0; i < P.num_players(); i++) { auto n_inputs = usage.inputs[i][T::clear::field_type()]; - string filename = Sub_Data_Files::get_input_filename(playerNames, i); + string filename = Sub_Data_Files::get_input_filename(playerNames, i, 0); if (n_inputs > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -137,7 +137,7 @@ void OfflineMachine::generate() int total = usage.edabits[{false, n_bits}] + usage.edabits[{true, n_bits}]; string filename = Sub_Data_Files::get_edabit_filename(playerNames, - n_bits); + n_bits, 0); if (total > 0) { ofstream out(filename, ios::binary); diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index cb25b4261..e98f1a3a1 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -279,7 +279,7 @@ void thread_info::Sub_Main_Func() printf("\tSignalling I have finished\n"); #endif wait_timer.start(); - queues->finished(job); + queues->finished(job, P.total_comm()); wait_timer.stop(); } } @@ -287,6 +287,11 @@ void thread_info::Sub_Main_Func() // final check Proc.check(); +#ifndef INSECURE + if (machine.opts.file_prep_per_thread) + Proc.DataF.prune(); +#endif + wait_timer.start(); queues->next(); wait_timer.stop(); @@ -314,16 +319,10 @@ void thread_info::Sub_Main_Func() #endif // wind down thread by thread - auto prep_stats = Proc.DataF.comm_stats(); - prep_stats += Proc.share_thread.DataF.comm_stats(); - prep_stats += Proc.Procp.bit_prep.comm_stats(); - for (auto& x : Proc.Procp.personal_bit_preps) - prep_stats += x->comm_stats(); machine.stats += Proc.stats; delete processor; - machine.comm_stats += P.comm_stats + prep_stats; - queues->finished(actual_usage); + queues->finished(actual_usage, P.total_comm()); delete MC2; delete MCp; diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 41308603b..2a5e090bd 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -29,6 +29,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; file_prep_per_thread = false; + trunc_error = 40; #ifdef VERBOSE verbose = true; #else @@ -326,6 +327,19 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, #endif lgp = max(lgp, gfp0::MAX_N_BITS); } + + set_trunc_error(opt); +} + +void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt) +{ + if (opt.get("-E")) + { + opt.get("-E")->getInt(trunc_error); +#ifdef VERBOSE + cerr << "Truncation error probability 2^-" << trunc_error << endl; +#endif + } } int OnlineOptions::prime_length() diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index de8f1e722..4b2fe4f8c 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -30,6 +30,7 @@ class OnlineOptions std::string cmd_private_output_file; bool verbose; bool file_prep_per_thread; + int trunc_error; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, @@ -37,10 +38,15 @@ class OnlineOptions OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, bool variable_prime_length = false); + template + OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T, + bool default_live_prep = true); ~OnlineOptions() {} void finalize(ez::ezOptionParser& opt, int argc, const char** argv); + void set_trunc_error(ez::ezOptionParser& opt); + int prime_length(); int prime_limbs(); diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp new file mode 100644 index 000000000..8961853e5 --- /dev/null +++ b/Processor/OnlineOptions.hpp @@ -0,0 +1,30 @@ +/* + * OnlineOptions.hpp + * + */ + +#ifndef PROCESSOR_ONLINEOPTIONS_HPP_ +#define PROCESSOR_ONLINEOPTIONS_HPP_ + +#include "OnlineOptions.h" + +template +OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, + const char** argv, T, bool default_live_prep) : + OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0, + default_live_prep, T::clear::prime_field) +{ + if (T::has_trunc_pr) + opt.add( + to_string(trunc_error).c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Probabilistic truncation error " + "(2^-x, default: 40)", // Help description. + "-E", // Flag token. + "--trunc-error" // Flag token. + ); +} + +#endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */ diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index 5c44b9087..4ca77daa1 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -40,21 +40,33 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir, + to_string(my_num) + get_suffix(thread_num); } -void PrepBase::print_left(const char* name, size_t n, const string& type_string) +void PrepBase::print_left(const char* name, size_t n, const string& type_string, + size_t used) { - if (n > 0) + if (n > 0 and OnlineOptions::singleton.verbose) cerr << "\t" << n << " " << name << " of " << type_string << " left" << endl; + + if (n > used / 10) + cerr << "Significant amount of unused " << name << " of " << type_string + << ". For more accurate benchmarks, " + << "consider reducing the batch size with -b." << endl; } void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, - int n_bits) + int n_bits, size_t used) { - if (n > 0) + if (n > 0 and OnlineOptions::singleton.verbose) { cerr << "\t~" << n * n_batch; if (not strict) cerr << " loose"; cerr << " edaBits of size " << n_bits << " left" << endl; } + + if (n > used / 10) + cerr << "Significant amount of unused edaBits of size " << n_bits + << ". For more accurate benchmarks, " + << "consider reducing the batch size with -b " + << "or increasing the bucket size with -B." << endl; } diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index bedba6299..ccc2f4b40 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -24,8 +24,10 @@ class PrepBase static string get_edabit_filename(const string& prep_data_dir, int n_bits, int my_num, int thread_num = 0); - static void print_left(const char* name, size_t n, const string& type_string); - static void print_left_edabits(size_t n, size_t n_batch, bool strict, int n_bits); + static void print_left(const char* name, size_t n, + const string& type_string, size_t used); + static void print_left_edabits(size_t n, size_t n_batch, bool strict, + int n_bits, size_t used); }; #endif /* PROCESSOR_PREPBASE_H_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index d9141855c..a78058cd1 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -243,10 +243,6 @@ class Processor : public ArithmeticProcessor cint get_inverse2(unsigned m); - // Print the processor state - template - friend ostream& operator<<(ostream& s,const Processor& P); - private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 6206e27c2..caea1e678 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -28,8 +28,8 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, bit_prep(bit_usage) { DataF.set_proc(this); + protocol.init(DataF, MC); DataF.set_protocol(protocol); - protocol.init_mul(this); bit_usage.set_num_players(P.num_players()); personal_bit_preps.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) @@ -39,22 +39,12 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, template SubProcessor::~SubProcessor() { - protocol.check(); - for (size_t i = 0; i < personal_bit_preps.size(); i++) { auto& x = personal_bit_preps[i]; -#ifdef VERBOSE - if (x->data_sent()) - cerr << "Sent for personal bit preprocessing threads of player " << i << ": " << - x->data_sent() * 1e-6 << " MB" << endl; -#endif delete x; } #ifdef VERBOSE - if (bit_prep.data_sent()) - cerr << "Sent for global bit preprocessing threads: " << - bit_prep.data_sent() * 1e-6 << " MB" << endl; if (not bit_usage.empty()) { cerr << "Mixed-circuit preprocessing cost:" << endl; @@ -423,7 +413,7 @@ void SubProcessor::muls(const vector& reg, int size) int n = reg.size() / 3; SubProcessor& proc = *this; - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < size; j++) { @@ -448,7 +438,7 @@ void SubProcessor::mulrs(const vector& reg) int n = reg.size() / 4; SubProcessor& proc = *this; - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < reg[4 * i]; j++) { @@ -470,7 +460,7 @@ void SubProcessor::mulrs(const vector& reg) template void SubProcessor::dotprods(const vector& reg, int size) { - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < size; i++) { auto it = reg.begin(); @@ -512,7 +502,7 @@ void SubProcessor::matmuls(const vector& source, assert(B + dim[1] * dim[2] <= source.end()); assert(C + dim[0] * dim[2] <= S.end()); - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) for (int j = 0; j < dim[2]; j++) { @@ -536,7 +526,7 @@ void SubProcessor::matmulsm(const CheckVector& source, assert(C + dim[0] * dim[2] <= S.end()); assert(Proc); - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) { auto ii = Proc->get_Ci().at(dim[3] + i); @@ -562,7 +552,7 @@ void SubProcessor::matmulsm(const CheckVector& source, template void SubProcessor::conv2ds(const Instruction& instruction) { - protocol.init_dotprod(this); + protocol.init_dotprod(); auto& args = instruction.get_start(); int output_h = args[0], output_w = args[1]; int inputs_h = args[2], inputs_w = args[3]; @@ -670,30 +660,4 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } -template -ostream& operator<<(ostream& s,const Processor& P) -{ - s << "Processor State" << endl; - s << "Char 2 Registers" << endl; - s << "Val\tClearReg\tSharedReg" << endl; - for (int i=0; i(), live_prep_default}; RingMachine(argc, argv, opt, online_opts); } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index add3f43cc..e422e0aa5 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -12,6 +12,7 @@ #include "Tools/ezOptionParser.h" #include "Math/gf2n.h" #include "OnlineMachine.hpp" +#include "OnlineOptions.hpp" template class U, template class V> @@ -25,7 +26,7 @@ template class U, template class V> HonestMajorityRingMachine::HonestMajorityRingMachine(int argc, const char** argv, ez::ezOptionParser& opt, int nplayers) { - OnlineOptions online_opts(opt, argc, argv); + OnlineOptions online_opts(opt, argc, argv, U<64>()); RingMachine(argc, argv, opt, online_opts, nplayers); } diff --git a/Processor/ThreadQueue.cpp b/Processor/ThreadQueue.cpp index 3f5b1c76d..6358e4a4a 100644 --- a/Processor/ThreadQueue.cpp +++ b/Processor/ThreadQueue.cpp @@ -27,6 +27,19 @@ void ThreadQueue::finished(const ThreadJob& job) out.push(job); } +void ThreadQueue::finished(const ThreadJob& job, const NamedCommStats& new_comm_stats) +{ + finished(job); + set_comm_stats(new_comm_stats); +} + +void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats) +{ + lock.lock(); + comm_stats = new_comm_stats; + lock.unlock(); +} + ThreadJob ThreadQueue::result() { auto res = out.pop(); @@ -38,3 +51,11 @@ ThreadJob ThreadQueue::result() lock.unlock(); return res; } + +NamedCommStats ThreadQueue::get_comm_stats() +{ + lock.lock(); + auto res = comm_stats; + lock.unlock(); + return res; +} diff --git a/Processor/ThreadQueue.h b/Processor/ThreadQueue.h index 2e994b3ad..f49722abb 100644 --- a/Processor/ThreadQueue.h +++ b/Processor/ThreadQueue.h @@ -13,6 +13,7 @@ class ThreadQueue WaitQueue in, out; Lock lock; int left; + NamedCommStats comm_stats; public: ThreadQueue() : @@ -28,7 +29,11 @@ class ThreadQueue void schedule(const ThreadJob& job); ThreadJob next(); void finished(const ThreadJob& job); + void finished(const ThreadJob& job, const NamedCommStats& comm_stats); ThreadJob result(); + + void set_comm_stats(const NamedCommStats& new_comm_stats); + NamedCommStats get_comm_stats(); }; #endif /* PROCESSOR_THREADQUEUE_H_ */ diff --git a/Processor/TruncPrTuple.h b/Processor/TruncPrTuple.h index 06a96845f..267acae48 100644 --- a/Processor/TruncPrTuple.h +++ b/Processor/TruncPrTuple.h @@ -10,26 +10,35 @@ #include using namespace std; +#include "OnlineOptions.h" + template class TruncPrTuple { public: + const static int n = 4; + int dest_base; int source_base; int k; int m; int n_shift; - TruncPrTuple(const vector& regs, size_t base) + TruncPrTuple(const vector& regs, size_t base) : + TruncPrTuple(regs.begin() + base) + { + } + + TruncPrTuple(vector::const_iterator it) { - dest_base = regs[base]; - source_base = regs[base + 1]; - k = regs[base + 2]; - m = regs[base + 3]; + dest_base = *it++; + source_base = *it++; + k = *it++; + m = *it++; n_shift = T::N_BITS - 1 - k; assert(m < k); assert(0 < k); - assert(m < T::N_BITS); + assert(m < T::n_bits()); } T upper(T mask) @@ -49,10 +58,17 @@ class TruncPrTupleWithGap : public TruncPrTuple { public: TruncPrTupleWithGap(const vector& regs, size_t base) : - TruncPrTuple(regs, base) + TruncPrTupleWithGap(regs.begin() + base) { } + TruncPrTupleWithGap(vector::const_iterator it) : + TruncPrTuple(it) + { + if (T::prime_field and small_gap()) + throw runtime_error("domain too small for chosen truncation error"); + } + T upper(T mask) { if (big_gap()) @@ -69,7 +85,12 @@ class TruncPrTupleWithGap : public TruncPrTuple bool big_gap() { - return this->k <= T::N_BITS - 40; + return this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error; + } + + bool small_gap() + { + return not big_gap(); } }; diff --git a/Programs/Source/keras_mnist_lenet_predict.mpc b/Programs/Source/keras_mnist_lenet_predict.mpc new file mode 100644 index 000000000..8b55de560 --- /dev/null +++ b/Programs/Source/keras_mnist_lenet_predict.mpc @@ -0,0 +1,44 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +# training_samples = MultiArray([60000, 28, 28], sfix) +# training_labels = MultiArray([60000, 10], sint) + +test_samples = MultiArray([1, 28, 28], sfix) +test_labels = MultiArray([1, 10], sint) + +# training_labels.input_from(0) +# training_samples.input_from(0) + +# test_labels.input_from(0) +# test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +model.build(test_samples.sizes) + +start = 0 +for var in model.trainable_variables: + var.assign_all(0) +# start = var.read_from_file(start) + +guesses = model.predict(test_samples, batch_size=1) + +print_ln('guess %s', guesses.reveal_nested()[:3]) +print_ln('truth %s', test_labels.reveal_nested()[:3]) diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h index 3dd34d173..c99d911a9 100644 --- a/Protocols/Atlas.h +++ b/Protocols/Atlas.h @@ -53,18 +53,13 @@ class Atlas : public ProtocolBase return shamir.get_n_relevant_players(); } - void init_mul(Preprocessing&, typename T::MAC_Check&) - { - init_mul(); - } - - void init_mul(SubProcessor* proc = 0); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void prepare(const typename T::open_type& product); void exchange(); T finalize_mul(int n = -1); - void init_dotprod(SubProcessor* proc); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index bb6f18bfb..c3a919b3d 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -38,7 +38,7 @@ array Atlas::get_double_sharing() } template -void Atlas::init_mul(SubProcessor*) +void Atlas::init_mul() { oss.reset(); oss2.reset(); @@ -47,10 +47,9 @@ void Atlas::init_mul(SubProcessor*) } template -typename T::clear Atlas::prepare_mul(const T& x, const T& y, int) +void Atlas::prepare_mul(const T& x, const T& y, int) { prepare(x * y); - return {}; } template @@ -98,9 +97,9 @@ T Atlas::finalize_mul(int) } template -void Atlas::init_dotprod(SubProcessor* proc) +void Atlas::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share = 0; } diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index e0c24e49e..2d28127c7 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -38,14 +38,17 @@ class Beaver : public ProtocolBase Beaver(Player& P) : prep(0), MC(0), P(P) {} - Player& branch(); + typename T::Protocol branch(); - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init(Preprocessing& prep, typename T::MAC_Check& MC); + + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); + void check(); + void start_exchange(); void stop_exchange(); diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 639930059..dc9814870 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -13,30 +13,34 @@ #include template -Player& Beaver::branch() +typename T::Protocol Beaver::branch() { - return P; + typename T::Protocol res(P); + res.prep = prep; + res.MC = MC; + res.init_mul(); + return res; } template -void Beaver::init_mul(SubProcessor* proc) +void Beaver::init(Preprocessing& prep, typename T::MAC_Check& MC) { - assert(proc != 0); - init_mul(proc->DataF, proc->MC); + this->prep = &prep; + this->MC = &MC; } template -void Beaver::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) +void Beaver::init_mul() { - this->prep = &prep; - this->MC = &MC; + assert(this->prep); + assert(this->MC); shares.clear(); opened.clear(); triples.clear(); } template -typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) +void Beaver::prepare_mul(const T& x, const T& y, int n) { (void) n; triples.push_back({{}}); @@ -44,7 +48,6 @@ typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) triple = prep->get_triple(n); shares.push_back(x - triple[0]); shares.push_back(y - triple[1]); - return 0; } template @@ -86,4 +89,11 @@ T Beaver::finalize_mul(int n) return tmp; } +template +void Beaver::check() +{ + assert(MC); + MC->Check(P); +} + #endif diff --git a/Protocols/BrainShare.h b/Protocols/BrainShare.h index 301ed9b0b..77f2e35f6 100644 --- a/Protocols/BrainShare.h +++ b/Protocols/BrainShare.h @@ -38,6 +38,8 @@ class BrainShare : public Rep3Share> const static int N_MASK_BITS = clear::N_BITS + S; const static int Z_BITS = 2 * (N_MASK_BITS) + 5 + S; + static const bool has_trunc_pr = false; + BrainShare() { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 853782459..fb55f0cf4 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -9,6 +9,7 @@ #include "Replicated.h" #include "Math/Z2k.h" #include "Processor/Instruction.h" +#include "Processor/TruncPrTuple.h" #include @@ -75,15 +76,14 @@ class FakeProtocol : public ProtocolBase return P; } - void init_mul(SubProcessor*) + void init_mul() { results.clear(); } - typename T::clear prepare_mul(const T& x, const T& y, int = -1) + void prepare_mul(const T& x, const T& y, int = -1) { results.push_back(x * y); - return {}; } void exchange() @@ -95,9 +95,9 @@ class FakeProtocol : public ProtocolBase return results.next(); } - void init_dotprod(SubProcessor* proc) + void init_dotprod() { - init_mul(proc); + init_mul(); dot_prod = {}; } @@ -177,19 +177,22 @@ class FakeProtocol : public ProtocolBase res += overflow; } #else -#ifdef RISKY_TRUNCATION_IN_EMULATION - T r; - r.randomize(G); + if (TruncPrTupleWithGap(regs, i).big_gap()) + { + T r; + r.randomize(G); - if (source.negative()) - res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + if (source.negative()) + res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + else + res = ((source + r) >> n_shift) - (r >> n_shift); + } else - res = ((source + r) >> n_shift) - (r >> n_shift); -#else - T r; - r.randomize_part(G, n_shift); - res = (source + r) >> n_shift; -#endif + { + T r; + r.randomize_part(G, n_shift); + res = (source + r) >> n_shift; + } #endif } } diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index f36a7b754..569c136e6 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -32,6 +32,9 @@ class FakeShare : public T, public ShareInterface typedef GC::FakeSecret bit_type; + static const bool has_trunc_pr = true; + static const bool dishonest_majority = false; + static string type_short() { return "emul"; diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 1e8021467..8a00c793c 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -6,14 +6,14 @@ #ifndef PROTOCOLS_HEMI_H_ #define PROTOCOLS_HEMI_H_ -#include "SPDZ.h" +#include "Semi.h" #include "HemiMatrixPrep.h" /** * Matrix multiplication optimized with semi-homomorphic encryption */ template -class Hemi : public SPDZ +class Hemi : public Semi { map, HemiMatrixPrep*> matrix_preps; @@ -22,7 +22,7 @@ class Hemi : public SPDZ public: Hemi(Player& P) : - SPDZ(P) + Semi(P) { } ~Hemi(); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index dc285c14c..e67b28a97 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -51,19 +51,20 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); - for (int i = 0; i < dim[0]; i++) + for (int k = 0; k < dim[1]; k++) { - auto ii = Proc->get_Ci().at(dim[3] + i); + for (int i = 0; i < dim[0]; i++) + { + auto kk = Proc->get_Ci().at(dim[4] + k); + auto ii = Proc->get_Ci().at(dim[3] + i); + A[{i, k}] = source.at(a + ii * dim[7] + kk); + } + for (int j = 0; j < dim[2]; j++) { auto jj = Proc->get_Ci().at(dim[6] + j); - for (int k = 0; k < dim[1]; k++) - { - auto kk = Proc->get_Ci().at(dim[4] + k); - auto ll = Proc->get_Ci().at(dim[5] + k); - A[{i, k}] = source.at(a + ii * dim[7] + kk); - B[{k, j}] = source.at(b + ll * dim[8] + jj); - } + auto ll = Proc->get_Ci().at(dim[5] + k); + B[{k, j}] = source.at(b + ll * dim[8] + jj); } } @@ -93,7 +94,8 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); MatrixMC mc; - beaver.init_mul(prep, mc); + beaver.init(prep, mc); + beaver.init_mul(); beaver.prepare_mul(A.from(0, i, subdim.data()), B.from(i, j, subdim.data() + 1)); beaver.exchange(); diff --git a/Protocols/HighGearKeyGen.cpp b/Protocols/HighGearKeyGen.cpp index 2618feba6..1c8f9f74d 100644 --- a/Protocols/HighGearKeyGen.cpp +++ b/Protocols/HighGearKeyGen.cpp @@ -19,5 +19,5 @@ template<> void PartSetup::key_and_mac_generation(Player& P, MachineBase& machine, int, false_type) { - HighGearKeyGen<2, 2>(P, params).run(*this, machine); + HighGearKeyGen<0, 0>(P, params).run(*this, machine); } diff --git a/Protocols/LowGearKeyGen.cpp b/Protocols/LowGearKeyGen.cpp index 2b149bc0f..61829b368 100644 --- a/Protocols/LowGearKeyGen.cpp +++ b/Protocols/LowGearKeyGen.cpp @@ -19,5 +19,5 @@ template<> void PairwiseSetup::key_and_mac_generation(Player& P, PairwiseMachine& machine, int, false_type) { - LowGearKeyGen<2>(P, machine, params).run(*this); + LowGearKeyGen<0>(P, machine, params).run(*this); } diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index a59820404..9ff92fb0e 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -126,7 +126,7 @@ typename KeyGenProtocol::vector_type KeyGenProtocol::schur_product( vector_type res; assert(x.size() == y.size()); auto& protocol = proc->protocol; - protocol.init_mul(proc); + protocol.init_mul(); for (size_t i = 0; i < x.size(); i++) protocol.prepare_mul(x[i], y[i]); protocol.exchange(); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index db3f8dc71..85e9c84a1 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -50,11 +50,13 @@ Tree_MAC_Check::Tree_MAC_Check(const typename U::mac_key_type::Scalar& ai, in template Tree_MAC_Check::~Tree_MAC_Check() { +#ifndef NO_SECURITY_CHECK if (WaitingForCheck() > 0) { cerr << endl << "SECURITY BUG: insufficient checking" << endl; terminate(); } +#endif } template diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index ce34b64e4..96f2c8138 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -121,21 +121,6 @@ void shuffle_triple_generation(vector>& triples, Player& P, #endif } -template -void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) -{ - int buffer_size = check_triples.size(); - - // shuffle - GlobalPRNG G(P); - for (int i = 0; i < buffer_size; i++) - { - int remaining = buffer_size - i; - int pos = G.get_uint(remaining); - swap(check_triples[i], check_triples[i + pos]); - } -} - template TripleShuffleSacrifice::TripleShuffleSacrifice() { @@ -251,32 +236,6 @@ void RingOnlyBitsFromSquaresPrep::buffer_bits() bits_from_square_in_ring(this->bits, this->buffer_size, &prep); } -template -void MaliciousRingPrep::buffer_edabits(bool strict, int n_bits, - ThreadQueues* queues) -{ - RunningTimer timer; -#ifndef NONPERSONAL_EDA - this->buffer_edabits_from_personal(strict, n_bits, queues); -#else - assert(this->proc != 0); - ShuffleSacrifice shuffle_sacrifice; - typedef typename T::bit_type::part_type bit_type; - vector> bits; - vector sums; - this->buffer_edabits_without_check(n_bits, sums, bits, - shuffle_sacrifice.minimum_n_inputs(), queues); - vector>& checked = this->edabits[{strict, n_bits}]; - shuffle_sacrifice.edabit_sacrifice(checked, sums, bits, - n_bits, *this->proc, strict, -1, queues); - if (strict) - this->sanitize(checked, n_bits, -1, queues); -#endif -#ifdef VERBOSE_EDA - cerr << "Total edaBit generation took " << timer.elapsed() << " seconds" << endl; -#endif -} - template void MalRepRingPrep::buffer_inputs(int player) { diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 1967994d0..f98e9797f 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -42,6 +42,7 @@ class MaliciousRep3Share : public Rep3Share typedef GC::MaliciousRepSecret bit_type; const static bool expensive = true; + static const bool has_trunc_pr = false; static string type_short() { diff --git a/Protocols/MaliciousRepPO.h b/Protocols/MaliciousRepPO.h index 62d4b1783..7b58970f7 100644 --- a/Protocols/MaliciousRepPO.h +++ b/Protocols/MaliciousRepPO.h @@ -11,17 +11,21 @@ template class MaliciousRepPO { +protected: Player& P; octetStream to_send; octetStream to_receive[2]; + PointerVector secrets; public: MaliciousRepPO(Player& P); + virtual ~MaliciousRepPO() {} void prepare_sending(const T& secret, int player); - void send(int player); - void receive(); + virtual void send(int player); + virtual void receive(); typename T::clear finalize(const T& secret); + typename T::clear finalize(); }; #endif /* PROTOCOLS_MALICIOUSREPPO_H_ */ diff --git a/Protocols/MaliciousRepPO.hpp b/Protocols/MaliciousRepPO.hpp index 38a3a274a..bae235647 100644 --- a/Protocols/MaliciousRepPO.hpp +++ b/Protocols/MaliciousRepPO.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_MALICIOUSREPPO_HPP_ +#define PROTOCOLS_MALICIOUSREPPO_HPP_ + #include "MaliciousRepPO.h" #include @@ -16,7 +19,10 @@ MaliciousRepPO::MaliciousRepPO(Player& P) : P(P) template void MaliciousRepPO::prepare_sending(const T& secret, int player) { - secret[2 - P.get_offset(player)].pack(to_send); + if (player == P.my_num()) + secrets.push_back(secret); + else + secret[2 - P.get_offset(player)].pack(to_send); } template @@ -24,7 +30,7 @@ void MaliciousRepPO::send(int player) { if (P.get_offset(player) == 2) P.send_to(player, to_send); - else + else if (P.my_num() != player) P.send_to(player, to_send.hash()); } @@ -42,3 +48,11 @@ typename T::clear MaliciousRepPO::finalize(const T& secret) { return secret.sum() + to_receive[0].template get(); } + +template +typename T::clear MaliciousRepPO::finalize() +{ + return finalize(secrets.next()); +} + +#endif diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 4be3fc63a..8ffbff7bd 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -61,8 +61,9 @@ void MaliciousBitOnlyRepPrep::set_protocol(typename T::Protocol& protocol) template void MaliciousBitOnlyRepPrep::init_honest(Player& P) { - honest_proc = new SubProcessor(honest_mc, honest_prep, - P); + if (not honest_proc) + honest_proc = new SubProcessor(honest_mc, + honest_prep, P); } template diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index ef61ec7b9..c9eb63cf6 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -6,6 +6,7 @@ #include "MamaPrep.h" #include "SemiMC.hpp" +#include "MalRepRingPrep.hpp" template MamaPrep::MamaPrep(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 734453d31..5cfa82b84 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -21,8 +21,6 @@ class OTPrep : public virtual BitPrep ~OTPrep(); void set_protocol(typename T::Protocol& protocol); - - NamedCommStats comm_stats(); }; /** diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index cef603a25..1393bb464 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -40,8 +40,9 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) // make sure not to use Montgomery multiplication T::open_type::next::template init(false); + assert(not triple_generator); triple_generator = new typename T::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), + BaseMachine::fresh_ot_setup(proc->P), proc->P.N, -1, OnlineOptions::singleton.batch_size, 1, params, proc->MC.get_alphai(), &proc->P); @@ -121,13 +122,4 @@ T Preprocessing::get_random_from_inputs(int nplayers) return res; } -template -NamedCommStats OTPrep::comm_stats() -{ - auto res = BitPrep::comm_stats(); - if (triple_generator) - res += triple_generator->comm_stats(); - return res; -} - #endif diff --git a/Protocols/NoProtocol.h b/Protocols/NoProtocol.h index b99ce4e3e..d8259eb0f 100644 --- a/Protocols/NoProtocol.h +++ b/Protocols/NoProtocol.h @@ -45,12 +45,12 @@ class NoProtocol : public ProtocolBase } // prepare next round of multiplications - void init_mul(SubProcessor*) + void init_mul() { } // schedule multiplication - typename T::clear prepare_mul(const T&, const T&, int = -1) + void prepare_mul(const T&, const T&, int = -1) { throw runtime_error("no multiplication preparation"); } diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 70371744d..d4f2ab0fd 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -22,6 +22,8 @@ class PostSacriRepRingShare : public Rep3Share2 static const int BIT_LENGTH = K; static const int SECURITY = S; + static const bool has_trunc_pr = false; + typedef SignedZ2 clear; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/PostSacrifice.h b/Protocols/PostSacrifice.h index 73ec766e4..54b178a74 100644 --- a/Protocols/PostSacrifice.h +++ b/Protocols/PostSacrifice.h @@ -30,8 +30,8 @@ class PostSacrifice : public ProtocolBase Player& branch(); - void init_mul(SubProcessor* proc); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange() { internal.exchange(); } T finalize_mul(int n = -1); diff --git a/Protocols/PostSacrifice.hpp b/Protocols/PostSacrifice.hpp index 4db3b73b4..0f72f4e81 100644 --- a/Protocols/PostSacrifice.hpp +++ b/Protocols/PostSacrifice.hpp @@ -25,9 +25,8 @@ Player& PostSacrifice::branch() } template -void PostSacrifice::init_mul(SubProcessor* proc) +void PostSacrifice::init_mul() { - (void) proc; // throw away unused operands operands.resize(results.size()); if ((int) results.size() >= OnlineOptions::singleton.batch_size) @@ -36,11 +35,11 @@ void PostSacrifice::init_mul(SubProcessor* proc) } template -typename T::clear PostSacrifice::prepare_mul(const T& x, const T& y, int n) +void PostSacrifice::prepare_mul(const T& x, const T& y, int n) { (void) n; operands.push_back({{x, y}}); - return internal.prepare_mul(x, y); + internal.prepare_mul(x, y); } template diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h new file mode 100644 index 000000000..e6a8eb525 --- /dev/null +++ b/Protocols/ProtocolSet.h @@ -0,0 +1,107 @@ +/* + * ProtocolSet.h + * + */ + +#ifndef PROTOCOLS_PROTOCOLSET_H_ +#define PROTOCOLS_PROTOCOLSET_H_ + +#include "Processor/Processor.h" +#include "GC/ShareThread.h" +#include "ProtocolSetup.h" + +/** + * Input, multiplication, and output protocol instance + * for an arithmetic share type + */ +template +class ProtocolSet +{ + DataPositions usage; + +public: + typename T::MAC_Check output; + typename T::LivePrep preprocessing; + SubProcessor processor; + typename T::Protocol& protocol; + typename T::Input& input; + + ProtocolSet(Player& P, typename T::mac_key_type mac_key) : + usage(P.num_players()), output(mac_key), preprocessing(0, usage), processor( + output, preprocessing, P), protocol(processor.protocol), input( + processor.input) + { + } + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + ProtocolSet(Player& P, const ProtocolSetup& setup) : + ProtocolSet(P, setup.get_mac_key()) + { + } + + ~ProtocolSet() + { + } +}; + +/** + * Input, multiplication, and output protocol instance + * for a binary share type + */ +template +class BinaryProtocolSet +{ + DataPositions usage; + typename T::LivePrep prep; + GC::ShareThread thread; + +public: + typename T::MAC_Check& output; + typename T::Protocol& protocol; + typename T::Input input; + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + BinaryProtocolSet(Player& P, const BinaryProtocolSetup& setup) : + usage(P.num_players()), prep(usage), thread(prep, P, + setup.get_mac_key()), output(*thread.MC), protocol( + *thread.protocol), input(output, prep, P) + { + } +}; + +/** + * Input, multiplication, and output protocol instance + * for an arithmetic share type and the corresponding binary one + */ +template +class MixedProtocolSet +{ + ProtocolSet arithmetic; + +public: + BinaryProtocolSet binary; + + typename T::MAC_Check& output; + typename T::LivePrep& preprocessing; + typename T::Protocol& protocol; + typename T::Input& input; + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + MixedProtocolSet(Player& P, const MixedProtocolSetup& setup) : + arithmetic(P, setup), binary(P, setup.binary), output( + arithmetic.output), preprocessing(arithmetic.preprocessing), protocol( + arithmetic.protocol), input(arithmetic.input) + { + } +}; + +#endif /* PROTOCOLS_PROTOCOLSET_H_ */ diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h new file mode 100644 index 000000000..b6d91b2bc --- /dev/null +++ b/Protocols/ProtocolSetup.h @@ -0,0 +1,95 @@ +/* + * ProtocolSetup.h + * + */ + +#ifndef PROTOCOLS_PROTOCOLSETUP_H_ +#define PROTOCOLS_PROTOCOLSETUP_H_ + +#include "Networking/Player.h" + +/** + * Global setup for an arithmetic share type + */ +template +class ProtocolSetup +{ + typename T::mac_key_type mac_key; + +public: + /** + * @param P communication instance (used for MAC generation if needed) + * @param prime_length length of prime if computing modulo a prime + * @param directory location to read MAC if needed + */ + ProtocolSetup(Player& P, int prime_length = 0, string directory = "") + { + // initialize fields + if (prime_length == 0) + prime_length = T::clear::MAX_N_BITS; + + T::clear::init_default(prime_length); + T::clear::next::init_default(prime_length, false); + + // must initialize MAC key for security of some protocols + T::read_or_generate_mac_key(directory, P, mac_key); + } + + ~ProtocolSetup() + { + T::LivePrep::teardown(); + } + + typename T::mac_key_type get_mac_key() const + { + return mac_key; + } +}; + +/** + * Global setup for a binary share type + */ +template +class BinaryProtocolSetup +{ + typename T::mac_key_type mac_key; + +public: + /** + * @param P communication instance (used for MAC generation if needed) + * @param directory location to read MAC if needed + */ + BinaryProtocolSetup(Player& P, string directory = "") + { + T::part_type::open_type::init_field(); + T::mac_key_type::init_field(); + T::part_type::read_or_generate_mac_key(directory, P, mac_key); + } + + typename T::mac_key_type get_mac_key() const + { + return mac_key; + } +}; + +/** + * Global setup for an arithmetic share type and the corresponding binary one + */ +template +class MixedProtocolSetup : public ProtocolSetup +{ +public: + BinaryProtocolSetup binary; + + /** + * @param P communication instance (used for MAC generation if needed) + * @param prime_length length of prime if computing modulo a prime + * @param directory location to read MAC if needed + */ + MixedProtocolSetup(Player& P, int prime_length = 0, string directory = "") : + ProtocolSetup(P, prime_length, directory), binary(P, directory) + { + } +}; + +#endif /* PROTOCOLS_PROTOCOLSETUP_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index d115b4c5c..e85065ac0 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -11,6 +11,7 @@ #include "Protocols/Replicated.h" #include "GC/ShareSecret.h" #include "ShareInterface.h" +#include "Processor/Instruction.h" template class ReplicatedPrep; template class ReplicatedRingPrep; @@ -67,6 +68,31 @@ class RepShare : public FixedVec, public ShareInterface assert(full); FixedVec::unpack(os); } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + shrsi(proc, inst, T::invertible); + } + + template + static void shrsi(SubProcessor&, const Instruction&, + true_type) + { + throw runtime_error("shrsi not implemented"); + } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst, + false_type) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; template @@ -94,6 +120,7 @@ class Rep3Share : public RepShare const static bool dishonest_majority = false; const static bool expensive = false; const static bool variable_players = false; + static const bool has_trunc_pr = true; static string type_short() { diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index c7a494525..23f28cf9b 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -31,7 +31,6 @@ class Rep3Share2 : public Rep3Share> typedef GC::SemiHonestRepSecret bit_type; - static const bool has_trunc_pr = true; static const bool has_split = true; Rep3Share2() @@ -132,17 +131,6 @@ class Rep3Share2 : public Rep3Share> } } } - - template - static void shrsi(SubProcessor& proc, const Instruction& inst) - { - for (int i = 0; i < inst.get_size(); i++) - { - auto& dest = proc.get_S_ref(inst.get_r(0) + i); - auto& source = proc.get_S_ref(inst.get_r(1) + i); - dest = source >> inst.get_n(); - } - } }; #endif /* PROTOCOLS_REP3SHARE2K_H_ */ diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index aa0fc7bce..6acfae421 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -60,6 +60,11 @@ class Rep4 : public ProtocolBase void trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type); + template + T finalize_mul(int n_bits, true_type); + template + T finalize_mul(int n_bits, false_type); + public: prngs_type rep_prngs; Player& P; @@ -70,14 +75,13 @@ class Rep4 : public ProtocolBase Rep4 branch(); - void init_mul(SubProcessor* proc = 0); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); void check(); - void init_dotprod(SubProcessor* proc); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index e77b4e6f5..a2deab2be 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -59,7 +59,7 @@ Rep4 Rep4::branch() } template -void Rep4::init_mul(SubProcessor*) +void Rep4::init_mul() { for (auto& x : add_shares) x.clear(); @@ -70,12 +70,6 @@ void Rep4::init_mul(SubProcessor*) channels.resize(P.num_players(), vector(P.num_players(), false)); } -template -void Rep4::init_mul(Preprocessing&, typename T::MAC_Check&) -{ - init_mul(); -} - template void Rep4::reset_joint_input(int n_inputs) { @@ -194,13 +188,12 @@ int Rep4::get_player(int offset) } template -typename T::clear Rep4::prepare_mul(const T& x, const T& y, int n_bits) +void Rep4::prepare_mul(const T& x, const T& y, int n_bits) { auto a = get_addshares(x, y); for (int i = 0; i < 5; i++) add_shares[i].push_back(a[i]); bit_lengths.push_back(n_bits); - return {}; } template @@ -215,7 +208,7 @@ array Rep4::get_addshares(const T& x, const T& y) } template -void Rep4::init_dotprod(SubProcessor*) +void Rep4::init_dotprod() { init_mul(); dotprod_shares = {}; @@ -260,10 +253,27 @@ void Rep4::exchange() } template -T Rep4::finalize_mul(int) +T Rep4::finalize_mul(int n_bits) { this->counter++; - return results.next().res; + if (n_bits == -1) + return results.next().res; + else + return finalize_mul(n_bits, T::clear::binary); +} + +template +template +T Rep4::finalize_mul(int n_bits, true_type) +{ + return results.next().res.mask(n_bits); +} + +template +template +T Rep4::finalize_mul(int, false_type) +{ + throw runtime_error("bit-wise multiplication not supported"); } template diff --git a/Protocols/Rep4Prep.hpp b/Protocols/Rep4Prep.hpp index 17915e43d..e871e82c9 100644 --- a/Protocols/Rep4Prep.hpp +++ b/Protocols/Rep4Prep.hpp @@ -54,7 +54,7 @@ template void Rep4RingPrep::buffer_squares() { generate_squares(this->squares, OnlineOptions::singleton.batch_size, - this->protocol, this->proc); + this->protocol); } template diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 3de9bfabc..67527a208 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -76,10 +76,13 @@ class ProtocolBase /// Single multiplication T mul(const T& x, const T& y); + /// Initialize protocol if needed (repeated call possible) + virtual void init(Preprocessing&, typename T::MAC_Check&) {} + /// Initialize multiplication round - virtual void init_mul(SubProcessor* proc) = 0; + virtual void init_mul() = 0; /// Schedule multiplication of operand pair - virtual typename T::clear prepare_mul(const T& x, const T& y, int n = -1) = 0; + virtual void prepare_mul(const T& x, const T& y, int n = -1) = 0; /// Run multiplication protocol virtual void exchange() = 0; /// Get next multiplication result @@ -88,7 +91,7 @@ class ProtocolBase virtual void finalize_mult(T& res, int n = -1); /// Initialize dot product round - void init_dotprod(SubProcessor* proc) { init_mul(proc); } + void init_dotprod() { init_mul(); } /// Add operand pair to current dot product void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); } /// Finish dot product @@ -132,6 +135,11 @@ class Replicated : public ReplicatedBase, public ProtocolBase PointerVector add_shares; typename T::clear dotprod_share; + template + void trunc_pr(const vector& regs, int size, U& proc, true_type); + template + void trunc_pr(const vector& regs, int size, U& proc, false_type); + public: typedef ReplicatedMC MAC_Check; typedef ReplicatedInput Input; @@ -149,17 +157,13 @@ class Replicated : public ReplicatedBase, public ProtocolBase share[my_num] = value; } - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - void init_mul(); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); void prepare_reshare(const typename T::clear& share, int n = -1); - void init_dotprod(SubProcessor*) { init_mul(); } void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 75dc785be..374ed89b1 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -11,12 +11,10 @@ #include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" -#include "SemiShare.h" -#include "SemiMC.h" #include "ReplicatedInput.h" #include "Rep3Share2k.h" -#include "SemiMC.hpp" +#include "ReplicatedPO.hpp" #include "Math/Z2k.hpp" template @@ -99,7 +97,8 @@ void ProtocolBase::multiply(vector& products, BaseMachine::thread_num); #endif - init_mul(&proc); + init(proc.DataF, proc.MC); + init_mul(); for (int i = begin; i < end; i++) prepare_mul(multiplicands[i].first, multiplicands[i].second); exchange(); @@ -110,7 +109,7 @@ void ProtocolBase::multiply(vector& products, template T ProtocolBase::mul(const T& x, const T& y) { - init_mul(0); + init_mul(); prepare_mul(x, y); exchange(); return finalize_mul(); @@ -146,20 +145,6 @@ T ProtocolBase::get_random() return res; } -template -void Replicated::init_mul(SubProcessor* proc) -{ - (void) proc; - init_mul(); -} - -template -void Replicated::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) -{ - (void) prep, (void) MC; - init_mul(); -} - template void Replicated::init_mul() { @@ -169,12 +154,11 @@ void Replicated::init_mul() } template -inline typename T::clear Replicated::prepare_mul(const T& x, +void Replicated::prepare_mul(const T& x, const T& y, int n) { typename T::value_type add_share = x.local_mul(y); prepare_reshare(add_share, n); - return add_share; } template @@ -276,109 +260,89 @@ void Replicated::randoms(T& res, int n_bits) res[i].randomize_part(shared_prngs[i], n_bits); } -template -void trunc_pr(const vector& regs, int size, - SubProcessor>& proc) +template +template +void Replicated::trunc_pr(const vector& regs, int size, U& proc, + false_type) { assert(regs.size() % 4 == 0); assert(proc.P.num_players() == 3); assert(proc.Proc != 0); - typedef SignedZ2 value_type; - typedef Rep3Share T; - bool generate = proc.P.my_num() == 2; + typedef typename T::clear value_type; + int gen_player = 2; + int comp_player = 1; + bool generate = P.my_num() == gen_player; + bool compute = P.my_num() == comp_player; + ArgList> infos(regs); + auto& S = proc.get_S(); + + octetStream cs; + ReplicatedInput input(P); + if (generate) { - octetStream os[2]; - for (size_t i = 0; i < regs.size(); i += 4) - { - TruncPrTuple info(regs, i); - for (int l = 0; l < size; l++) + SeededPRNG G; + for (auto info : infos) + for (int i = 0; i < size; i++) { - auto& res = proc.get_S_ref(regs[i] + l); - auto& G = proc.Proc->secure_prng; - auto mask = G.template get(); - auto unmask = info.upper(mask); - T shares[4]; - shares[0].randomize_to_sum(mask, G); - shares[1].randomize_to_sum(unmask, G); - shares[2].randomize_to_sum(info.msb(mask), G); - res.randomize(G); - shares[3] = res; - for (int i = 0; i < 2; i++) - { - for (int j = 0; j < 4; j++) - shares[j][i].pack(os[i]); - } + auto r = G.get(); + input.add_mine(info.upper(r)); + if (info.small_gap()) + input.add_mine(info.msb(r)); + (r + S[info.source_base + i][0]).pack(cs); } - } - for (int i = 0; i < 2; i++) - proc.P.send_to(i, os[i]); + P.send_to(comp_player, cs); } else + input.add_other(gen_player); + + if (compute) { - octetStream os; - proc.P.receive_player(2, os); - OffsetPlayer player(proc.P, 1 - 2 * proc.P.my_num()); - typedef SemiShare semi_type; - vector> to_open; - PointerVector> mask_shares[3]; - for (size_t i = 0; i < regs.size(); i += 4) - for (int l = 0; l < size; l++) + P.receive_player(gen_player, cs); + for (auto info : infos) + for (int i = 0; i < size; i++) { - SemiShare share; - auto& x = proc.get_S_ref(regs[i + 1] + l); - if (proc.P.my_num() == 0) - share = x.sum(); - else - share = x[0]; - for (auto& mask_share : mask_shares) - mask_share.push_back(os.get()); - to_open.push_back(share + mask_shares[0].next()); - auto& res = proc.get_S_ref(regs[i] + l); - auto& a = res[1 - proc.P.my_num()]; - a.unpack(os); + auto c = cs.get() + S[info.source_base + i].sum(); + input.add_mine(info.upper(c)); + if (info.small_gap()) + input.add_mine(info.msb(c)); } - PointerVector opened; - DirectSemiMC> MC; - MC.POpen_(opened, to_open, player); - os.reset_write_head(); - for (size_t i = 0; i < regs.size(); i += 4) + } + + input.add_other(comp_player); + input.exchange(); + init_mul(); + + for (auto info : infos) + for (int i = 0; i < size; i++) { - int k = regs[i + 2]; - int m = regs[i + 3]; - int n_shift = value_type::N_BITS - 1 - k; - assert(m < k); - assert(0 < k); - assert(m < value_type::N_BITS); - for (int l = 0; l < size; l++) + auto c_prime = input.finalize(comp_player); + auto r_prime = input.finalize(gen_player); + S[info.dest_base + i] = c_prime - r_prime; + + if (info.small_gap()) { - auto& res = proc.get_S_ref(regs[i] + l); - auto masked = opened.next() << n_shift; - auto shifted = (masked << 1) >> (n_shift + m + 1); - auto diff = SemiShare::constant(shifted, - player.my_num()) - mask_shares[1].next(); - auto msb = masked >> (value_type::N_BITS - 1); - auto bit_mask = mask_shares[2].next(); - auto overflow = (bit_mask - + SemiShare::constant(msb, player.my_num()) - - bit_mask * msb * 2); - auto res_share = diff + (overflow << (k - m)); - auto& a = res[1 - proc.P.my_num()]; - auto& b = res[proc.P.my_num()]; - b = res_share - a; - b.pack(os); + auto c_dprime = input.finalize(comp_player); + auto r_msb = input.finalize(gen_player); + S[info.dest_base + i] += ((r_msb + c_dprime) + << (info.k - info.m)); + prepare_mul(r_msb, c_dprime); } } - player.exchange(os); - for (size_t i = 0; i < regs.size(); i += 4) - for (int l = 0; l < size; l++) - proc.get_S_ref(regs[i] + l)[proc.P.my_num()] += - os.get(); - } + + exchange(); + + for (auto info : infos) + for (int i = 0; i < size; i++) + if (info.small_gap()) + S[info.dest_base + i] -= finalize_mul() + << (info.k - info.m + 1); } template -void trunc_pr(const vector& regs, int size, SubProcessor& proc) +template +void Replicated::trunc_pr(const vector& regs, int size, U& proc, + true_type) { (void) regs, (void) size, (void) proc; throw runtime_error("trunc_pr not implemented"); @@ -390,7 +354,7 @@ void Replicated::trunc_pr(const vector& regs, int size, U& proc) { this->trunc_rounds++; - ::trunc_pr(regs, size, proc); + trunc_pr(regs, size, proc, T::clear::characteristic_two); } #endif diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 7d62838a3..9bb3c30a3 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -72,9 +72,8 @@ class ReplicatedInput : public PrepLessInput PrepLessInput(proc), proc(proc), P(P), protocol(P) { assert(T::length == 2); - InputBase::P = &P; - InputBase::os.resize(P.num_players()); expect.resize(P.num_players()); + this->reset_all(P); } void reset(int player); diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 741d2c490..1cfac4a16 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -71,7 +71,7 @@ template inline void ReplicatedInput::finalize_other(int player, T& target, octetStream& o, int n_bits) { - int offset = player - P.my_num(); + int offset = player - this->my_num; if (offset == 1 or offset == -2) { typename T::value_type t; diff --git a/Protocols/ReplicatedPO.h b/Protocols/ReplicatedPO.h new file mode 100644 index 000000000..a533a5b1a --- /dev/null +++ b/Protocols/ReplicatedPO.h @@ -0,0 +1,24 @@ +/* + * ReplicatedPO.h + * + */ + +#ifndef PROTOCOLS_REPLICATEDPO_H_ +#define PROTOCOLS_REPLICATEDPO_H_ + +#include "MaliciousRepPO.h" + +template +class ReplicatedPO : public MaliciousRepPO +{ +public: + ReplicatedPO(Player& P) : + MaliciousRepPO(P) + { + } + + void send(int player); + void receive(); +}; + +#endif /* PROTOCOLS_REPLICATEDPO_H_ */ diff --git a/Protocols/ReplicatedPO.hpp b/Protocols/ReplicatedPO.hpp new file mode 100644 index 000000000..aecd85b3f --- /dev/null +++ b/Protocols/ReplicatedPO.hpp @@ -0,0 +1,21 @@ +/* + * ReplicatedPO.cpp + * + */ + +#include "ReplicatedPO.h" + +#include "MaliciousRepPO.hpp" + +template +void ReplicatedPO::send(int player) +{ + if (this->P.get_offset(player) == 2) + this->P.send_to(player, this->to_send); +} + +template +void ReplicatedPO::receive() +{ + this->P.receive_relative(1, this->to_receive[0]); +} diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 8c3ed3f13..8a30749c3 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -184,6 +184,15 @@ class RingPrep : public virtual BitPrep template void sanitize(vector>& edabits, int n_bits); + template + void buffer_personal_edabits_without_check_pre(int n_bits, + Player& P, typename T::Input& input, typename BT::Input& bit_input, + int input_player, int buffer_size); + template + void buffer_personal_edabits_without_check_post(int n_bits, + vector& sums, vector >& bits, typename T::Input& input, + typename BT::Input& bit_input, int input_player, int begin, int end); + public: RingPrep(SubProcessor* proc, DataPositions& usage); virtual ~RingPrep(); @@ -224,6 +233,13 @@ class RingPrep : public virtual BitPrep template class SemiHonestRingPrep : public virtual RingPrep { + template + void buffer_bits(false_type, false_type); + template + void buffer_bits(true_type, false_type); + template + void buffer_bits(false_type, true_type); + public: SemiHonestRingPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), @@ -232,7 +248,7 @@ class SemiHonestRingPrep : public virtual RingPrep } virtual ~SemiHonestRingPrep() {} - virtual void buffer_bits() { this->buffer_bits_without_check(); } + virtual void buffer_bits(); virtual void buffer_inputs(int player) { this->buffer_inputs_as_usual(player, this->proc); } @@ -358,11 +374,6 @@ template class ReplicatedPrep : public virtual ReplicatedRingPrep, public virtual SemiHonestRingPrep { - template - void buffer_bits(false_type); - template - void buffer_bits(true_type); - public: ReplicatedPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), @@ -384,7 +395,7 @@ class ReplicatedPrep : public virtual ReplicatedRingPrep, } void buffer_squares() { ReplicatedRingPrep::buffer_squares(); } - void buffer_bits(); + void buffer_bits() { SemiHonestRingPrep::buffer_bits(); } }; #endif /* PROTOCOLS_REPLICATEDPREP_H_ */ diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 2b8aa1604..916ee6b8f 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -56,24 +56,23 @@ BufferPrep::~BufferPrep() << " bit generation" << endl; #endif - if (OnlineOptions::singleton.verbose) - { - this->print_left("triples", triples.size() * T::default_length, - type_string); - -#define X(KIND) \ - this->print_left(#KIND, KIND.size(), type_string); - X(squares) - X(inverses) - X(bits) - X(dabits) + this->print_left("triples", triples.size() * T::default_length, type_string, + this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) + * T::default_length); + +#define X(KIND, TYPE) \ + this->print_left(#KIND, KIND.size(), type_string, \ + this->usage.files.at(T::clear::field_type()).at(TYPE)); + X(squares, DATA_SQUARE) + X(inverses, DATA_INVERSE) + X(bits, DATA_BIT) + X(dabits, DATA_DABIT) #undef X - for (auto& x : this->edabits) - { - this->print_left_edabits(x.second.size(), x.second[0].size(), - x.first.first, x.first.second); - } + for (auto& x : this->edabits) + { + this->print_left_edabits(x.second.size(), x.second[0].size(), + x.first.first, x.first.second, this->usage.edabits[x.first]); } } @@ -100,7 +99,9 @@ RingPrep::~RingPrep() template void BitPrep::set_protocol(typename T::Protocol& protocol) { - this->protocol = new typename T::Protocol(protocol.branch()); + if (not this->protocol) + this->protocol = new typename T::Protocol(protocol.branch()); + this->protocol->init_mul(); auto proc = this->proc; if (proc and proc->Proc) this->base_player = proc->Proc->thread_num; @@ -202,16 +203,16 @@ template void ReplicatedRingPrep::buffer_squares() { generate_squares(this->squares, this->buffer_size, - this->protocol, this->proc); + this->protocol); } template void generate_squares(vector>& squares, int n_squares, - U* protocol, SubProcessor* proc) + U* protocol) { assert(protocol != 0); squares.resize(n_squares); - protocol->init_mul(proc); + protocol->init_mul(); for (size_t i = 0; i < squares.size(); i++) { auto& square = squares[i]; @@ -289,7 +290,7 @@ void BufferPrep::get_two_no_count(Dtype dtype, T& a, T& b) template void XOR(vector& res, vector& x, vector& y, - typename T::Protocol& prot, SubProcessor* proc) + typename T::Protocol& prot) { assert(x.size() == y.size()); int buffer_size = x.size(); @@ -302,7 +303,7 @@ void XOR(vector& res, vector& x, vector& y, return; } - prot.init_mul(proc); + prot.init_mul(); for (int i = 0; i < buffer_size; i++) prot.prepare_mul(x[i], y[i]); prot.exchange(); @@ -337,13 +338,14 @@ void buffer_bits_from_squares(RingPrep& prep) template template -void ReplicatedPrep::buffer_bits(true_type) +void SemiHonestRingPrep::buffer_bits(true_type, false_type) { if (this->protocol->get_n_relevant_players() > 10 - or OnlineOptions::singleton.bits_from_squares) + or OnlineOptions::singleton.bits_from_squares + or T::dishonest_majority) buffer_bits_from_squares(*this); else - ReplicatedRingPrep::buffer_bits(); + this->buffer_bits_without_check(); } template @@ -409,10 +411,9 @@ void MaliciousRingPrep::buffer_personal_dabits_without_check( auto& P = this->proc->P; auto &party = GC::ShareThread::s(); typedef typename T::bit_type::part_type BT; - SubProcessor bit_proc(party.MC->get_part_MC(), + typename BT::Input bit_input(party.MC->get_part_MC(), this->proc->bit_prep, this->proc->P); typename T::Input input(*this->proc, this->proc->MC); - typename BT::Input bit_input(bit_proc, bit_proc.MC); input.reset_all(P); bit_input.reset_all(P); SeededPRNG G; @@ -454,10 +455,24 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, typename BT::Input bit_input(proc, proc.MC); input.reset_all(P); bit_input.reset_all(P); - SeededPRNG G; assert(begin % BT::default_length == 0); int buffer_size = end - begin; + buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, + input_player, buffer_size); + input.exchange(); + bit_input.exchange(); + buffer_personal_edabits_without_check_post(n_bits, sums, bits, input, + bit_input, input_player, begin, end); +} + +template +template +void RingPrep::buffer_personal_edabits_without_check_pre(int n_bits, + Player& P, typename T::Input& input, typename BT::Input& bit_input, + int input_player, int buffer_size) +{ int n_chunks = DIV_CEIL(buffer_size, BT::default_length); + SeededPRNG G; if (input_player == P.my_num()) { for (int i = 0; i < n_chunks; i++) @@ -482,8 +497,16 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, for (int i = 0; i < BT::default_length; i++) input.add_other(input_player); } - input.exchange(); - bit_input.exchange(); +} + +template +template +void RingPrep::buffer_personal_edabits_without_check_post(int n_bits, + vector& sums, vector >& bits, typename T::Input& input, + typename BT::Input& bit_input, int input_player, int begin, int end) +{ + int buffer_size = end - begin; + int n_chunks = DIV_CEIL(buffer_size, BT::default_length); for (int i = 0; i < buffer_size; i++) sums[begin + i] = input.finalize(input_player); assert(bits.size() == size_t(n_bits)); @@ -600,18 +623,18 @@ void BitPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, assert(proc != 0); int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits; - auto stat = proc->P.comm_stats; + auto stat = proc->P.total_comm(); buffer_bits_from_players(player_bits, G, *proc, this->base_player, buffer_size, 1); auto& prot = *protocol; - XOR(bits, player_bits[0], player_bits[1], prot, proc); + XOR(bits, player_bits[0], player_bits[1], prot); for (int i = 2; i < n_relevant_players; i++) - XOR(bits, bits, player_bits[i], prot, proc); + XOR(bits, bits, player_bits[i], prot); this->base_player++; (void) stat; #ifdef VERBOSE_PREP cerr << "bit generation" << endl; - (proc->P.comm_stats - stat).print(true); + (proc->P.total_comm() - stat).print(true); #endif } @@ -730,9 +753,22 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, vector> player_ints(n_relevant, vector(buffer_size)); vector>> parts(n_relevant, vector>(n_bits, vector(buffer_size / dl))); + InScope in_scope(this->do_count, false); + assert(this->proc != 0); + auto& P = proc->P; + typename T::Input input(*this->proc, this->proc->MC); + typename BT::Input bit_input(bit_proc, bit_proc.MC); + input.reset_all(P); + bit_input.reset_all(P); + assert(begin % BT::default_length == 0); + for (int i = 0; i < n_relevant; i++) + buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, + i, buffer_size); + input.exchange(); + bit_input.exchange(); for (int i = 0; i < n_relevant; i++) - buffer_personal_edabits_without_check<0>(n_bits, player_ints[i], parts[i], - bit_proc, i, 0, buffer_size); + buffer_personal_edabits_without_check_post(n_bits, player_ints[i], + parts[i], input, bit_input, i, 0, buffer_size); vector>> player_bits(n_bits, vector>(n_relevant)); for (int i = 0; i < n_bits; i++) @@ -754,7 +790,7 @@ template void RingPrep::buffer_edabits_without_check(int n_bits, vector>& edabits, int buffer_size) { - auto stat = this->proc->P.comm_stats; + auto stat = this->proc->P.total_comm(); typedef typename T::bit_type::part_type bit_type; vector> bits; vector sums; @@ -763,7 +799,7 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& (void) stat; #ifdef VERBOSE_PREP cerr << "edaBit generation" << endl; - (proc->P.comm_stats - stat).print(true); + (proc->P.total_comm() - stat).print(true); #endif } @@ -920,40 +956,38 @@ void RingPrep::sanitize(vector>& edabits, int n_bits) delete &MCB; } -template<> -inline -void SemiHonestRingPrep>::buffer_bits() -{ - assert(protocol != 0); - bits_from_random(bits, *protocol); -} - template -void bits_from_random(vector& bits, typename T::Protocol& protocol) +template +void SemiHonestRingPrep::buffer_bits(false_type, true_type) { - while (bits.size() < (size_t)OnlineOptions::singleton.batch_size) - { - Rep3Share share = protocol.get_random(); - for (int j = 0; j < gf2n::degree(); j++) + assert(this->protocol != 0); + if (not T::dishonest_majority and T::variable_players) + // Shamir + this->buffer_bits_without_check(); + else + while (this->bits.size() < (size_t) OnlineOptions::singleton.batch_size) { - bits.push_back(share & 1); - share >>= 1; + auto share = this->get_random(); + for (int j = 0; j < T::open_type::degree(); j++) + { + this->bits.push_back(share & 1); + share >>= 1; + } } - } } template template -void ReplicatedPrep::buffer_bits(false_type) +void SemiHonestRingPrep::buffer_bits(false_type, false_type) { - ReplicatedRingPrep::buffer_bits(); + this->buffer_bits_without_check(); } template -void ReplicatedPrep::buffer_bits() +void SemiHonestRingPrep::buffer_bits() { assert(this->protocol != 0); - buffer_bits<0>(T::clear::prime_field); + buffer_bits(T::clear::prime_field, T::clear::characteristic_two); } template diff --git a/Protocols/Semi2k.h b/Protocols/Semi.h similarity index 75% rename from Protocols/Semi2k.h rename to Protocols/Semi.h index 69cf63aad..e290ca0eb 100644 --- a/Protocols/Semi2k.h +++ b/Protocols/Semi.h @@ -3,8 +3,8 @@ * */ -#ifndef PROTOCOLS_SEMI2K_H_ -#define PROTOCOLS_SEMI2K_H_ +#ifndef PROTOCOLS_SEMI_H_ +#define PROTOCOLS_SEMI_H_ #include "SPDZ.h" #include "Processor/TruncPrTuple.h" @@ -13,12 +13,12 @@ * Dishonest-majority protocol for computation modulo a power of two */ template -class Semi2k : public SPDZ +class Semi : public SPDZ { SeededPRNG G; public: - Semi2k(Player& P) : + Semi(Player& P) : SPDZ(P) { } @@ -30,6 +30,19 @@ class Semi2k : public SPDZ void trunc_pr(const vector& regs, int size, SubProcessor& proc) + { + trunc_pr(regs, size, proc, T::clear::characteristic_two); + } + + template + void trunc_pr(const vector&, int, SubProcessor&, true_type) + { + throw not_implemented(); + } + + template + void trunc_pr(const vector& regs, int size, + SubProcessor& proc, false_type) { if (this->P.num_players() > 2) throw runtime_error("probabilistic truncation " @@ -60,4 +73,4 @@ class Semi2k : public SPDZ } }; -#endif /* PROTOCOLS_SEMI2K_H_ */ +#endif /* PROTOCOLS_SEMI_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index a9df48b4a..ee5e83202 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -7,7 +7,7 @@ #define PROTOCOLS_SEMI2KSHARE_H_ #include "SemiShare.h" -#include "Semi2k.h" +#include "Semi.h" #include "OT/Rectangle.h" #include "GC/SemiSecret.h" #include "GC/square64.h" @@ -27,7 +27,7 @@ class Semi2kShare : public SemiShare> typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef Semi2k Protocol; + typedef Semi Protocol; typedef SemiPrep2k LivePrep; typedef Semi2kShare prep_type; @@ -35,8 +35,6 @@ class Semi2kShare : public SemiShare> typedef OTTripleGenerator TripleGenerator; typedef Z2kSquare Rectangle; - typedef GC::SemiSecret bit_type; - static const bool has_split = true; Semi2kShare() diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index ed044c461..c2dd90858 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SEMISHARE_H_ #include "Protocols/Beaver.h" +#include "Protocols/Semi.h" #include "Processor/DummyProtocol.h" #include "ShareInterface.h" @@ -16,7 +17,7 @@ using namespace std; template class Input; template class SemiMC; template class DirectSemiMC; -template class SPDZ; +template class Semi; template class SemiPrep; template class SemiInput; template class PrivateOutput; @@ -59,7 +60,7 @@ class SemiShare : public T, public ShareInterface typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef Semi Protocol; typedef SemiPrep LivePrep; typedef LivePrep TriplePrep; @@ -69,12 +70,15 @@ class SemiShare : public T, public ShareInterface typedef T sacri_type; typedef typename T::Square Rectangle; +#ifndef NO_MIXED_CIRCUITS typedef GC::SemiSecret bit_type; +#endif const static bool needs_ot = true; const static bool dishonest_majority = true; const static bool variable_players = true; const static bool expensive = false; + static const bool has_trunc_pr = true; static string type_short() { return "D" + string(1, T::type_char()); } diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 3d2bf469b..f722886eb 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -62,20 +62,8 @@ class Shamir : public ProtocolBase void reset(); void init_mul(); - void init_mul(SubProcessor* proc); - template - void init_mul(V*) - { - init_mul(); - } - template - void init_mul(const V&, const W&) - { - init_mul(); - } - - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); void start_exchange(); @@ -85,7 +73,7 @@ class Shamir : public ProtocolBase T finalize(int n_input_players); - void init_dotprod(SubProcessor* proc = 0); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index d387f3b47..9fe10bdea 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -80,13 +80,6 @@ void Shamir::reset() resharing->reset(i); } -template -void Shamir::init_mul(SubProcessor* proc) -{ - (void) proc; - init_mul(); -} - template void Shamir::init_mul() { @@ -96,13 +89,12 @@ void Shamir::init_mul() } template -typename T::clear Shamir::prepare_mul(const T& x, const T& y, int n) +void Shamir::prepare_mul(const T& x, const T& y, int n) { (void) n; auto add_share = x * y * rec_factor; if (P.my_num() < n_mul_players) resharing->add_mine(add_share); - return {}; } template @@ -157,9 +149,9 @@ T Shamir::finalize(int n_relevant_players) } template -void Shamir::init_dotprod(SubProcessor* proc) +void Shamir::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share = 0; } diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index fe509321c..81e859319 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -10,7 +10,6 @@ #include "Tools/PointerVector.h" #include "GC/BitAdder.h" -#include "MalRepRingPrep.hpp" #include "LimitedPrep.hpp" inline @@ -25,6 +24,21 @@ ShuffleSacrifice::ShuffleSacrifice(int B, int C) : { } +template +void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) +{ + int buffer_size = check_triples.size(); + + // shuffle + GlobalPRNG G(P); + for (int i = 0; i < buffer_size; i++) + { + int remaining = buffer_size - i; + int pos = G.get_uint(remaining); + swap(check_triples[i], check_triples[i + pos]); + } +} + template void TripleShuffleSacrifice::triple_combine(vector >& triples, vector >& to_combine, Player& P, diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 33883c66f..03a91ff25 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -26,7 +26,6 @@ class Spdz2kPrep : public virtual MaliciousRingPrep, MascotTriplePrep* bit_prep; SubProcessor* bit_proc; typename BitShare::MAC_Check* bit_MC; - typename BitShare::Protocol* bit_protocol; public: Spdz2kPrep(SubProcessor* proc, DataPositions& usage); @@ -41,8 +40,6 @@ class Spdz2kPrep : public virtual MaliciousRingPrep, #ifdef SPDZ2K_BIT void get_dabit(T& a, GC::TinySecret& b); #endif - - NamedCommStats comm_stats(); }; #endif /* PROTOCOLS_SPDZ2KPREP_H_ */ diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index f5c9cdce6..815277614 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -25,7 +25,6 @@ Spdz2kPrep::Spdz2kPrep(SubProcessor* proc, DataPositions& usage) : bit_MC = 0; bit_proc = 0; bit_prep = 0; - bit_protocol = 0; } template @@ -36,7 +35,6 @@ Spdz2kPrep::~Spdz2kPrep() delete bit_prep; delete bit_proc; delete bit_MC; - delete bit_protocol; } } @@ -50,10 +48,8 @@ void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) // just dummies bit_pos = DataPositions(proc->P.num_players()); bit_prep = new MascotTriplePrep(bit_proc, bit_pos); - bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_prep->params.amplify = false; - bit_protocol = new typename BitShare::Protocol(proc->P); - bit_prep->set_protocol(*bit_protocol); + bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_MC->set_prep(*bit_prep); this->proc->MC.set_prep(*this); } @@ -65,7 +61,7 @@ void MaliciousRingPrep::buffer_bits() RingPrep::buffer_bits_without_check(); assert(this->protocol != 0); auto& protocol = *this->protocol; - protocol.init_dotprod(this->proc); + protocol.init_dotprod(); auto one = T::constant(1, protocol.P.my_num(), this->proc->MC.get_alphai()); GlobalPRNG G(protocol.P); for (auto& bit : this->bits) @@ -238,12 +234,29 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, } template -NamedCommStats Spdz2kPrep::comm_stats() +void MaliciousRingPrep::buffer_edabits(bool strict, int n_bits, + ThreadQueues* queues) { - auto res = OTPrep::comm_stats(); - if (bit_prep) - res += bit_prep->comm_stats(); - return res; + RunningTimer timer; +#ifndef NONPERSONAL_EDA + this->buffer_edabits_from_personal(strict, n_bits, queues); +#else + assert(this->proc != 0); + ShuffleSacrifice shuffle_sacrifice; + typedef typename T::bit_type::part_type bit_type; + vector> bits; + vector sums; + this->buffer_edabits_without_check(n_bits, sums, bits, + shuffle_sacrifice.minimum_n_inputs(), queues); + vector>& checked = this->edabits[{strict, n_bits}]; + shuffle_sacrifice.edabit_sacrifice(checked, sums, bits, + n_bits, *this->proc, strict, -1, queues); + if (strict) + this->sanitize(checked, n_bits, -1, queues); +#endif +#ifdef VERBOSE_EDA + cerr << "Total edaBit generation took " << timer.elapsed() << " seconds" << endl; +#endif } #endif diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h index afbf2c850..c12b4f5fe 100644 --- a/Protocols/SpdzWise.h +++ b/Protocols/SpdzWise.h @@ -38,22 +38,23 @@ class SpdzWise : public ProtocolBase SpdzWise(Player& P); virtual ~SpdzWise(); - Player& branch(); + typename T::Protocol branch(); - void init(SubProcessor* proc); + void init(Preprocessing&, typename T::MAC_Check& MC); - void init_mul(SubProcessor* proc); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); - void init_dotprod(SubProcessor*); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); void add_to_check(const T& x); void check(); + void maybe_check(); int get_n_relevant_players() { return internal.get_n_relevant_players(); } diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 40f3cee71..2ea08ba46 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -19,34 +19,40 @@ SpdzWise::~SpdzWise() } template -Player& SpdzWise::branch() +typename T::Protocol SpdzWise::branch() { - return P; + typename T::Protocol res(P); + res.mac_key = mac_key; + return res; +} + +template +void SpdzWise::init(Preprocessing&, typename T::MAC_Check& MC) +{ + mac_key = MC.get_alphai(); } template -void SpdzWise::init(SubProcessor* proc) +void SpdzWise::maybe_check() { - assert(proc != 0); - mac_key = proc->MC.get_alphai(); + assert(not mac_key.is_zero()); if ((int) results.size() >= OnlineOptions::singleton.batch_size) check(); } template -void SpdzWise::init_mul(SubProcessor* proc) +void SpdzWise::init_mul() { - init(proc); + maybe_check(); internal.init_mul(); internal2.init_mul(); } template -typename T::clear SpdzWise::prepare_mul(const T& x, const T& y, int) +void SpdzWise::prepare_mul(const T& x, const T& y, int) { internal.prepare_mul(x.get_share(), y.get_share()); internal.prepare_mul(x.get_mac(), y.get_share()); - return {}; } template @@ -67,9 +73,9 @@ void SpdzWise::exchange() } template -void SpdzWise::init_dotprod(SubProcessor* proc) +void SpdzWise::init_dotprod() { - init(proc); + maybe_check(); internal.init_dotprod(); internal2.init_dotprod(); } diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index ef7f549bf..e0d508e51 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -12,6 +12,7 @@ SpdzWiseInput::SpdzWiseInput(SubProcessor* proc, Player& P) : { assert(proc != 0); mac_key = proc->MC.get_alphai(); + checker.init(proc->DataF, proc->MC); } template @@ -76,7 +77,7 @@ void SpdzWiseInput::exchange() shares[i][j].set_mac(honest_mult.finalize_mul()); checker.results.push_back(shares[i][j]); } - checker.init(proc); + checker.maybe_check(); } template diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index f88e97d64..9cb86017a 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -9,19 +9,21 @@ #include "MaliciousShamirShare.h" #include "SquarePrep.h" #include "Math/gfp.h" +#include "ProtocolSet.h" #include "ReplicatedPrep.hpp" #include "Spdz2kPrep.hpp" #include "ShamirMC.hpp" #include "MaliciousRepPO.hpp" #include "MaliciousShamirPO.hpp" +#include "GC/RepPrep.hpp" template void SpdzWisePrep::buffer_triples() { assert(this->protocol != 0); assert(this->proc != 0); - this->protocol->init_mul(this->proc); + this->protocol->init_mul(); generate_triples_initialized(this->triples, OnlineOptions::singleton.batch_size, this->protocol); } @@ -38,8 +40,11 @@ void SpdzWisePrep>>::buffer_bits() { typedef MaliciousRep3Share part_type; vector bits; - typename part_type::Honest::Protocol protocol(this->protocol->P); - bits_from_random(bits, protocol); + ProtocolSet set(this->proc->P, {}); + auto& protocol = set.protocol; + auto& prep = set.preprocessing; + for (int i = 0; i < buffer_size; i++) + bits.push_back(prep.get_bit()); protocol.init_mul(); for (auto& bit : bits) protocol.prepare_mul(bit, this->proc->MC.get_alphai()); @@ -99,7 +104,7 @@ void SpdzWisePrep::buffer_inputs(int player) vector rs(OnlineOptions::singleton.batch_size); auto& P = this->proc->P; this->inputs.resize(P.num_players()); - this->protocol->init_mul(this->proc); + this->protocol->init_mul(); for (auto& r : rs) { r = this->protocol->get_random(); diff --git a/Protocols/SpdzWiseRing.hpp b/Protocols/SpdzWiseRing.hpp index 30904c386..36e638d14 100644 --- a/Protocols/SpdzWiseRing.hpp +++ b/Protocols/SpdzWiseRing.hpp @@ -36,7 +36,7 @@ void SpdzWiseRing::zero_check(check_type t) while(bits.size() > 1) { auto& protocol = zero_proc.protocol; - protocol.init_mul(&zero_proc); + protocol.init_mul(); for (int i = bits.size() - 2; i >= 0; i -= 2) protocol.prepare_mul(bits[i], bits[i + 1]); protocol.exchange(); diff --git a/Protocols/SquarePrep.h b/Protocols/SquarePrep.h index fcdc2c239..be0913b37 100644 --- a/Protocols/SquarePrep.h +++ b/Protocols/SquarePrep.h @@ -10,7 +10,7 @@ template void generate_squares(vector>& squares, int n_squares, - U* protocol, SubProcessor* proc); + U* protocol); template class SquarePrep : public BufferPrep @@ -22,8 +22,8 @@ class SquarePrep : public BufferPrep void buffer_squares() { - generate_squares(this->squares, this->buffer_size, &this->proc->protocol, - this->proc); + generate_squares(this->squares, this->buffer_size, + &this->proc->protocol); } public: diff --git a/README.md b/README.md index daa658a5f..bd1075121 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ The following table lists all protocols that are fully supported. | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | See [this paper](https://eprint.iacr.org/2020/300) for an explanation -of the various security models and high-level introduction to +of the various security models and a high-level introduction to multi-party computation. ##### Finding the most efficient protocol @@ -131,8 +131,8 @@ there are a few things to consider: dot products. - Fixed-point multiplication: Three- and four-party replicated secret - sharing modulo a power of two allow a special probabilistic - truncation protocol (see [Dalskov et + sharing as well semi-honest full-threshold protocols allow a special + probabilistic truncation protocol (see [Dalskov et al.](https://eprint.iacr.org/2019/131) and [Dalskov et al.](https://eprint.iacr.org/2020/1330)). You can activate it by adding `program.use_trunc_pr = True` at the beginning of your diff --git a/Scripts/decompile.py b/Scripts/decompile.py new file mode 100755 index 000000000..0142ba69a --- /dev/null +++ b/Scripts/decompile.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +import sys, os + +sys.path.append('.') + +from Compiler.instructions_base import Instruction +from Compiler.program import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +for tapename in Program.read_tapes(sys.argv[1]): + with open('Programs/Bytecode/%s.asm' % tapename, 'w') as out: + for i, inst in enumerate(Tape.read_instructions(tapename)): + print(inst, '#', i, file=out) diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py new file mode 100755 index 000000000..15959ee68 --- /dev/null +++ b/Scripts/memory-usage.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +import sys, os +import collections + +sys.path.append('.') + +from Compiler.program import * +from Compiler.instructions_base import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +res = collections.defaultdict(lambda: 0) +m = 0 + +for tapename in Program.read_tapes(sys.argv[1]): + for inst in Tape.read_instructions(tapename): + t = inst.type + if issubclass(t, DirectMemoryInstruction): + res[t.arg_format[0]] = max(inst.args[1].i + inst.size, + res[t.arg_format[0]]) + for arg in inst.args: + if isinstance(arg, RegisterArgFormat): + m = max(m, arg.i + inst.size) + +print (res) +print (m) + diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 3c0891e61..7e5e6d449 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -34,39 +34,26 @@ run_player() { if ! test -e $SPDZROOT/logs; then mkdir $SPDZROOT/logs fi - if [[ $bin = Player-Online.x || $bin =~ 'party.x' ]]; then - params="$prog $* -pn $port -h localhost" - if [[ ! ($bin =~ 'rep' || $bin =~ 'brain' || $bin =~ 'yao') ]]; then - params="$params -N $players" - fi - else - params="$port localhost $prog $*" + params="$prog $* -pn $port -h localhost" + if $SPDZROOT/$bin 2>&1 | grep -q '^-N,'; then + params="$params -N $players" fi - rem=$(($players - 2)) if test "$prog"; then log_prefix=$prog- fi - for i in $(seq 0 $rem); do + set -o pipefail + for i in $(seq 0 $[players-1]); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params log=$SPDZROOT/logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | { if test $i = 0; then tee $log; else cat > $log; fi; } & + codes[$i]=$! + done + for i in $(seq 0 $[players-1]); do + wait ${codes[$i]} || return 1 done - last_player=$(($players - 1)) - i=$last_player - >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params - $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 - wait } -sleep 0.5 - -#mkdir /dev/shm/Player-Data - players=${PLAYERS:-2} SPDZROOT=${SPDZROOT:-.} - -#. Scripts/setup.sh - -mkdir logs 2> /dev/null diff --git a/Scripts/test_streaming.sh b/Scripts/test_streaming.sh index 0ff2fb336..62a493084 100755 --- a/Scripts/test_streaming.sh +++ b/Scripts/test_streaming.sh @@ -15,3 +15,7 @@ done ./stream-fake-mascot-triples.x & Scripts/mascot.sh test_thread_mul -f || exit 1 + +./stream-fake-mascot-triples.x & + +Scripts/mascot.sh test_thread_mul -f || exit 1 diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index 5dd4f45db..ed6c01441 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -27,7 +27,8 @@ if test "$flags"; then cpu=amd64 fi - cp -av bin/`uname`-$cpu/* . + cp -av bin/`uname`-$cpu/* . || { echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2; exit 1; } fi mkdir Player-Data 2> /dev/null +exit 0 diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index 4ef3406f0..567e57885 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -9,6 +9,15 @@ #include #include +void BitVector::assign(const BitVector& K) +{ + if (nbits != K.nbits) + { + resize(K.nbits); + } + memcpy(bytes, K.bytes, nbytes); +} + void BitVector::resize_zero(size_t new_nbits) { size_t old_nbytes = nbytes; diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 055610519..54d9ed109 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -33,14 +33,7 @@ class BitVector public: - void assign(const BitVector& K) - { - if (nbits != K.nbits) - { - resize(K.nbits); - } - memcpy(bytes, K.bytes, nbytes); - } + void assign(const BitVector& K); void assign_bytes(char* new_bytes, int len) { resize(len*8); diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index c669081f8..9dd15804c 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -26,7 +26,7 @@ void BufferBase::setup(ifstream* f, int length, const string& filename, bool BufferBase::is_pipe() { struct stat buf; - if (stat(filename.c_str(), &buf)) + if (stat(filename.c_str(), &buf) == 0) return S_ISFIFO(buf.st_mode); else return false; @@ -113,6 +113,17 @@ void BufferBase::prune() rename(tmp_name.c_str(), filename.c_str()); file->open(filename.c_str(), ios::in | ios::binary); } +#ifdef VERBOSE + else + { + cerr << "Not pruning " << filename << " because it's "; + if (file) + cerr << "closed"; + else + cerr << "unused"; + cerr << endl; + } +#endif } void BufferBase::purge() diff --git a/Tools/Bundle.h b/Tools/Bundle.h index ed4b982e3..7859e3e4c 100644 --- a/Tools/Bundle.h +++ b/Tools/Bundle.h @@ -31,7 +31,7 @@ class Bundle : public vector { } - void compare(Player& P) + void compare(PlayerBase& P) { P.unchecked_broadcast(*this); for (auto& os : *this) diff --git a/Tools/TimerWithComm.cpp b/Tools/TimerWithComm.cpp new file mode 100644 index 000000000..2a5e8e12a --- /dev/null +++ b/Tools/TimerWithComm.cpp @@ -0,0 +1,23 @@ +/* + * TimerWithComm.cpp + * + */ + +#include "TimerWithComm.h" + +void TimerWithComm::start(const NamedCommStats& stats) +{ + Timer::start(); + last_stats = stats; +} + +void TimerWithComm::stop(const NamedCommStats& stats) +{ + Timer::stop(); + total_stats += stats - last_stats; +} + +double TimerWithComm::mb_sent() +{ + return total_stats.sent * 1e-6; +} diff --git a/Tools/TimerWithComm.h b/Tools/TimerWithComm.h new file mode 100644 index 000000000..2f3976a20 --- /dev/null +++ b/Tools/TimerWithComm.h @@ -0,0 +1,23 @@ +/* + * TimerWithComm.h + * + */ + +#ifndef TOOLS_TIMERWITHCOMM_H_ +#define TOOLS_TIMERWITHCOMM_H_ + +#include "time-func.h" +#include "Networking/Player.h" + +class TimerWithComm : public Timer +{ + NamedCommStats total_stats, last_stats; + +public: + void start(const NamedCommStats& stats = {}); + void stop(const NamedCommStats& stats = {}); + + double mb_sent(); +}; + +#endif /* TOOLS_TIMERWITHCOMM_H_ */ diff --git a/Tools/benchmarking.cpp b/Tools/benchmarking.cpp new file mode 100644 index 000000000..e956f15ec --- /dev/null +++ b/Tools/benchmarking.cpp @@ -0,0 +1,15 @@ +/* + * benchmarking.cpp + * + */ + +#include "benchmarking.h" + +void insecure_fake() +{ +#if defined(INSECURE) or defined(INSECURE_FAKE) + cerr << "WARNING: insecure preprocessing" << endl; +#else + insecure("preprocessing"); +#endif +} diff --git a/Tools/benchmarking.h b/Tools/benchmarking.h index 0ca65b761..13fa9c365 100644 --- a/Tools/benchmarking.h +++ b/Tools/benchmarking.h @@ -8,6 +8,7 @@ #include #include +#include using namespace std; // call before insecure benchmarking functionality @@ -26,4 +27,6 @@ inline void insecure(string message, bool warning = true) #endif } +void insecure_fake(); + #endif /* TOOLS_BENCHMARKING_H_ */ diff --git a/Tools/octetStream.h b/Tools/octetStream.h index df920a302..cd90b0e94 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -35,7 +35,9 @@ class bigint; class FlexBuffer; /** - * Buffer for networking communication with a pointer for sequential reading + * Buffer for network communication with a pointer for sequential reading. + * When sent over the network or stored in a file, the length is prefixed + * as eight bytes in little-endian order. */ class octetStream { diff --git a/Tools/random.cpp b/Tools/random.cpp index 7a0cd1dab..7cf1924f3 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -13,7 +13,7 @@ using namespace std; PRNG::PRNG() : - cnt(0), n_cached_bits(0), cached_bits(0) + cnt(0), n_cached_bits(0), cached_bits(0), initialized(false) { #if defined(__AES__) || !defined(__x86_64__) #ifdef USE_AES @@ -83,6 +83,7 @@ void PRNG::SecureSeed(Player& player) void PRNG::InitSeed() { + initialized = true; #ifdef USE_AES if (useC) { aes_schedule(KeyScheduleC,seed); } @@ -122,6 +123,7 @@ void PRNG::print_state() const void PRNG::hash() { + assert(initialized); #ifndef USE_AES unsigned char tmp[RAND_SIZE + SEED_SIZE]; randombytes_buf_deterministic(tmp, sizeof tmp, seed); diff --git a/Tools/random.h b/Tools/random.h index d22be6e88..5e65d8350 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -61,6 +61,8 @@ class PRNG int n_cached_bits; word cached_bits; + bool initialized; + void hash(); // Hashes state to random and sets cnt=0 void next(); diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index e8026a952..f1158cfa6 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -387,7 +387,7 @@ int generate(ez::ezOptionParser& opt); int main(int argc, const char** argv) { - insecure("preprocessing"); + insecure_fake(); bigint::init_thread(); FakeParams params; diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp new file mode 100644 index 000000000..45e5f3371 --- /dev/null +++ b/Utils/binary-example.cpp @@ -0,0 +1,140 @@ +/* + * binary-example.cpp + * + */ + +#include "GC/TinierSecret.h" +#include "GC/PostSacriSecret.h" +#include "GC/CcdSecret.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/AtlasSecret.h" +#include "GC/TinyMC.h" +#include "GC/VectorInput.h" +#include "GC/PostSacriBin.h" +#include "Protocols/ProtocolSet.h" + +#include "GC/ShareSecret.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/RepPrep.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ThreadMaster.hpp" +#include "Protocols/Atlas.hpp" +#include "Protocols/MaliciousRepPrep.hpp" +#include "Protocols/Share.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Shamir.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Machines/ShamirMachine.hpp" +#include "Machines/Rep4.hpp" + +template +void run(int argc, char** argv); + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 3) + { + cerr << "Usage: " << argv[0] + << " [protocol [bit length [threshold]]]" + << endl; + exit(1); + } + + string protocol = "Tinier"; + if (argc > 3) + protocol = argv[3]; + + if (protocol == "Tinier") + run>(argc, argv); + else if (protocol == "Rep3") + run(argc, argv); + else if (protocol == "Rep4") + run(argc, argv); + else if (protocol == "PS") + run(argc, argv); + else if (protocol == "Semi") + run(argc, argv); + else if (protocol == "CCD" or protocol == "MalCCD" or protocol == "Atlas") + { + int nparties = (atoi(argv[2])); + int threshold = (nparties - 1) / 2; + if (argc > 5) + threshold = atoi(argv[5]); + assert(2 * threshold < nparties); + ShamirOptions::s().threshold = threshold; + ShamirOptions::s().nparties = nparties; + + if (protocol == "CCD") + run>>(argc, argv); + else if (protocol == "MalCCD") + run>(argc, argv); + else + run(argc, argv); + } + else + { + cerr << "Unknown protocol: " << protocol << endl; + exit(1); + } +} + +template +void run(int argc, char** argv) +{ + // run 16-bit computation by default + int n_bits = 16; + if (argc > 4) + n_bits = atoi(argv[4]); + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + CryptoPlayer P(N); + + // protocol setup (domain, MAC key if needed etc) + BinaryProtocolSetup setup(P); + + // set of protocols (input, multiplication, output) + BinaryProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; + + int n = 10; + vector a(n), b(n); + + input.reset_all(P); + for (int i = 0; i < n; i++) + input.add_from_all(i + P.my_num(), n_bits); + input.exchange(); + for (int i = 0; i < n; i++) + { + a[i] = input.finalize(0, n_bits); + b[i] = input.finalize(1, n_bits); + } + + protocol.init_mul(); + for (int i = 0; i < n; i++) + protocol.prepare_mul(a[i], b[i], n_bits); + protocol.exchange(); + output.init_open(P, n); + for (int i = 0; i < n; i++) + { + auto c = protocol.finalize_mul(n_bits); + output.prepare_open(c); + } + output.exchange(P); + + cout << "result: "; + for (int i = 0; i < n; i++) + cout << output.finalize_open() << " "; + cout << endl; + + protocol.check(); + output.Check(P); +} diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp new file mode 100644 index 000000000..532d705e4 --- /dev/null +++ b/Utils/mixed-example.cpp @@ -0,0 +1,137 @@ +/* + * mixed-example.cpp + * + */ + +#include "Protocols/ProtocolSet.h" + +#include "Machines/SPDZ.hpp" +#include "Machines/Semi2k.hpp" +#include "Machines/Rep.hpp" +#include "Machines/Rep4.hpp" +#include "Machines/Atlas.hpp" + +template +void run(char** argv); + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 3) + { + cerr << "Usage: " << argv[0] + << " [protocol]" + << endl; + exit(1); + } + + string protocol = "SPDZ2k"; + if (argc > 3) + protocol = argv[3]; + + if (protocol == "SPDZ2k") + run>(argv); + else if (protocol == "Semi2k") + run>(argv); + else if (protocol == "Rep3") + run>(argv); + else if (protocol == "Rep4") + run>(argv); + else if (protocol == "Atlas") + run>>(argv); + else + { + cerr << "Unknown protocol: " << protocol << endl; + exit(1); + } +} + +template +void run(char** argv) +{ + // reduce batch size + OnlineOptions::singleton.bucket_size = 5; + OnlineOptions::singleton.batch_size = 100; + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + CryptoPlayer P(N); + + // protocol setup (domain, MAC key if needed etc) + MixedProtocolSetup setup(P); + + // set of protocols (bit_input, multiplication, output) + MixedProtocolSet set(P, setup); + auto& output = set.output; + auto& bit_input = set.binary.input; + auto& bit_protocol = set.binary.protocol; + auto& bit_output = set.binary.output; + auto& prep = set.preprocessing; + + int n = 10; + int n_bits = 16; + vector a(n), b(n); + + // inputs in binary domain + bit_input.reset_all(P); + for (int i = 0; i < n; i++) + bit_input.add_from_all(i + P.my_num(), n_bits); + bit_input.exchange(); + for (int i = 0; i < n; i++) + { + a[i] = bit_input.finalize(0, n_bits); + b[i] = bit_input.finalize(1, n_bits); + } + + // compute AND in binary domain + bit_protocol.init_mul(); + for (int i = 0; i < n; i++) + bit_protocol.prepare_mul(a[i], b[i], n_bits); + bit_protocol.exchange(); + bit_protocol.check(); + bit_output.init_open(P, n * n_bits); + PointerVector> dabits; + for (int i = 0; i < n; i++) + { + auto c = bit_protocol.finalize_mul(n_bits); + + // mask result with dabits and open + for (int j = 0; j < n_bits; j++) + { + dabits.push_back({}); + auto& dabit = dabits.back(); + prep.get_dabit(dabit.first, dabit.second); + bit_output.prepare_open( + typename T::bit_type::part_type( + dabit.second.get_bit(0) + c.get_bit(j))); + } + } + bit_output.exchange(P); + output.init_open(P, n); + for (int i = 0; i < n; i++) + { + T res; + // unmask via XOR and recombine + for (int j = 0; j < n_bits; j++) + { + typename T::clear masked = bit_output.finalize_open().get_bit(0); + auto mask = dabits.next().first; + res += (mask - mask * masked * 2 + + T::constant(masked, P.my_num(), setup.get_mac_key())) + << j; + } + output.prepare_open(res); + } + output.exchange(P); + bit_output.Check(P); + + cout << "result: "; + for (int i = 0; i < n; i++) + cout << output.finalize_open() << " "; + cout << endl; + + output.Check(P); +} diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 87247fee8..9cae6953f 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -11,8 +11,10 @@ #include "Machines/SPDZ.hpp" #include "Machines/MalRep.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/Semi2k.hpp" #include "Protocols/CowGearShare.h" #include "Protocols/CowGearPrep.hpp" +#include "Protocols/ProtocolSet.h" template void run(char** argv, int prime_length); @@ -42,6 +44,8 @@ int main(int argc, char** argv) run>>(argv, prime_length); else if (protocol == "SPDZ2k") run>(argv, 0); + else if (protocol == "Semi2k") + run>(argv, 0); else if (protocol == "Shamir" or protocol == "MalShamir") { int nparties = (atoi(argv[2])); @@ -74,35 +78,14 @@ void run(char** argv, int prime_length) Names N(my_number, n_parties, "localhost", port_base); CryptoPlayer P(N); - // initialize fields - T::clear::init_default(prime_length); - T::clear::next::init_default(prime_length, false); + // protocol setup (domain, MAC key if needed etc) + ProtocolSetup setup(P, prime_length); - // must initialize MAC key for security of some protocols - typename T::mac_key_type mac_key; - T::read_or_generate_mac_key("", P, mac_key); - - // global OT setup - BaseMachine machine; - if (T::needs_ot) - machine.ot_setups.push_back({P}); - - // keeps tracks of preprocessing usage (triples etc) - DataPositions usage; - usage.set_num_players(P.num_players()); - - // output protocol - typename T::MAC_Check output(mac_key); - - // various preprocessing - typename T::LivePrep preprocessing(0, usage); - SubProcessor processor(output, preprocessing, P); - - // input protocol - typename T::Input input(processor, output); - - // multiplication protocol - typename T::Protocol protocol(P); + // set of protocols (input, multiplication, output) + ProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; int n = 1000; vector a(n), b(n); @@ -119,19 +102,23 @@ void run(char** argv, int prime_length) b[i] = input.finalize(1); } - protocol.init_dotprod(&processor); + protocol.init_dotprod(); for (int i = 0; i < n; i++) protocol.prepare_dotprod(a[i], b[i]); protocol.next_dotprod(); protocol.exchange(); c = protocol.finalize_dotprod(n); + + // protocol check before revealing results + protocol.check(); + output.init_open(P); output.prepare_open(c); output.exchange(P); result = output.finalize_open(); cout << "result: " << result << endl; - output.Check(P); - T::LivePrep::teardown(); + // result check after opening + output.Check(P); } diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp index 5aa85a054..517056e72 100644 --- a/Utils/stream-fake-mascot-triples.cpp +++ b/Utils/stream-fake-mascot-triples.cpp @@ -27,13 +27,18 @@ void* run(void* arg) int count = 0; while (true) { - gfpvar triple[3]; - for (int i = 0; i < 2; i++) - triple[i].randomize(G); - triple[2] = triple[0] * triple[1]; - for (int i = 0; i < 3; i++) - files.output_shares(triple[i]); - count++; + for (int i = 0; i < 100000; i++) + { + gfpvar triple[3]; + for (int i = 0; i < 2; i++) + triple[i].randomize(G); + triple[2] = triple[0] * triple[1]; + for (int i = 0; i < 3; i++) + files.output_shares(triple[i]); + count++; + } + // take a break to make them wait + sleep(1); } cerr << "failed after " << count << endl; return 0; @@ -41,7 +46,7 @@ void* run(void* arg) int main() { - insecure("preprocessing"); + insecure_fake(); typedef Share T; int nplayers = 2; int lgp = 128; diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 074fb3400..749ba2878 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -58,9 +58,6 @@ class YaoEvaluator: public GC::Thread>, int get_n_worker_threads() { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } - - NamedCommStats comm_stats() - { return super::comm_stats() + player.comm_stats; } }; inline void YaoEvaluator::load_gate(YaoGate& gate) diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index e6ae6cda1..647369a15 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -120,8 +120,3 @@ void YaoGarbler::process_receiver_inputs() receiver_input_keys.pop_front(); } } - -NamedCommStats YaoGarbler::comm_stats() -{ - return super::comm_stats() + player.comm_stats; -} diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 038fe432f..0608336c8 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -71,8 +71,6 @@ class YaoGarbler: public GC::Thread>, int get_threshold() { return master.threshold; } long get_gate_id() { return gate_id(thread_num); } - - NamedCommStats comm_stats(); }; inline YaoGarbler& YaoGarbler::s() diff --git a/Yao/YaoWire.h b/Yao/YaoWire.h index ddaf3b9c2..92f3ec614 100644 --- a/Yao/YaoWire.h +++ b/Yao/YaoWire.h @@ -23,6 +23,10 @@ class YaoWire : public Phase static void xors(GC::Processor& processor, const vector& args, size_t start, size_t end); + template + static void andm(GC::Processor& processor, + const BaseInstruction& instruction); + void XOR(const YaoWire& left, const YaoWire& right) { key_ = left.key_ ^ right.key_; diff --git a/Yao/YaoWire.hpp b/Yao/YaoWire.hpp index bb3b14068..aa04fe357 100644 --- a/Yao/YaoWire.hpp +++ b/Yao/YaoWire.hpp @@ -46,4 +46,24 @@ void YaoWire::xors(GC::Processor& processor, const vector& args, processor.xors(args, start, end); } +template +void YaoWire::andm(GC::Processor& processor, + const BaseInstruction& instruction) +{ + + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto &dest = processor.S[instruction.get_r(0) + i]; + int n = min(unsigned(unit), instruction.get_n() - i * unit); + dest.resize_regs(n); + for (int j = 0; j < n; j++) + if (processor.C[instruction.get_r(2) + i].get_bit(j)) + dest.get_reg(j) = + processor.S[instruction.get_r(1) + i].get_reg(j); + else + dest.get_reg(j).public_input(0); + } +} + #endif /* YAO_YAOWIRE_HPP_ */ diff --git a/doc/Doxyfile b/doc/Doxyfile index 771f8cf13..3dd299405 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/conf.py b/doc/conf.py index 57f730add..86bb12d46 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -21,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = u'MP-SPDZ' -copyright = u'2021, CSIRO\'s Data61' +copyright = u'2022, CSIRO\'s Data61' author = u'Marcel Keller' # The short X.Y version @@ -185,7 +185,8 @@ breathe_projects = {'mp-spdz': 'xml'} breathe_default_project = 'mp-spdz' import subprocess -subprocess.call('doxygen', shell=True) +if (subprocess.call('doxygen', shell=True)): + raise Exception('doxygen failed') def setup(app): app.add_css_file('custom.css') diff --git a/doc/index.rst b/doc/index.rst index d7a13e941..d2a2c4dcd 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,10 +1,16 @@ Welcome to MP-SPDZ's documentation! =================================== -This documentation provides a reference to the most important -high-level functionality provided by the MP-SPDZ compiler. For a -tutorial and documentation on how to run programs, the -implemented protocols etc. see https://github.com/data61/MP-SPDZ. +If you're new to MP-SPDZ, consider the following: + +1. `Quickstart tutorial `_ +2. `Implemented protocols `_ +3. :ref:`troubleshooting` + +Unlike the `Readme +`_, this +documentation provides a reference for more detailed aspects of the +software. Compilation process ------------------- diff --git a/doc/io.rst b/doc/io.rst index 5184ab338..a4d00cee8 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -83,6 +83,8 @@ covering both client code and server-side high-level code. :py:func:`Compiler.types.MultiArray.reveal_to_clients`. The same functions are available for :py:class:`~Compiler.types.sfix` and :py:class:`~Compiler.types.Array`, respectively. +See also :ref:`client ref` below. + Secret Shares ~~~~~~~~~~~~~ @@ -114,3 +116,11 @@ etc. Note also that all types based on :py:class:`~Compiler.types.sfix`) share the same memory, and that the address is only a base address. This means that vectors will be written to the memory starting at the given address. + +.. _client ref: + +Reference +~~~~~~~~~ + +.. doxygenclass:: Client + :members: diff --git a/doc/low-level.rst b/doc/low-level.rst index 0aaf3708e..c70bf5b65 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -83,109 +83,24 @@ number of parties. .. code-block:: cpp - // initialize fields - T::clear::init_default(prime_length); + ProtocolSetup setup(P, prime_length); We have to use a specific prime for computation modulo a prime. This deterministically generates one of the desired length if necessary. For computation modulo a power of two, this does not do -anything. +anything. Some protocols use an information-theoretic tag that is +constant throughout the protocol. This code reads it from storage if +available or generates a fresh one otherwise. .. code-block:: cpp - T::clear::next::init_default(prime_length, false); + ProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; -For computation modulo a prime, it is more efficient to use Montgomery -representation, which is not compatible with the MASCOT offline phase -however. This line initializes another field instance for MASCOT -without using Montgomery representation. - -.. code-block:: cpp - - // must initialize MAC key for security of some protocols - typename T::mac_key_type mac_key; - T::read_or_generate_mac_key("", P, mac_key); - -Some protocols use an information-theoretic tag that is constant -throughout the protocol. This codes reads it from storage if available -or generates a fresh one otherwise. - -.. code-block:: cpp - - // global OT setup - BaseMachine machine; - if (T::needs_ot) - machine.ot_setups.push_back({P}); - -Many protocols for a dishonest majority use oblivious transfer. This -block runs a few instances to seed the oblivious transfer -extension. The resulting setup only works for one thread. For several -threads, you need to add sufficiently many instances to -:member:`ot_setups` and set :member:`BaseMachine::thread_num` -(thread-local) to a different consecutive number in every thread. - -.. code-block:: cpp - - // keeps tracks of preprocessing usage (triples etc) - DataPositions usage; - usage.set_num_players(P.num_players()); - -To help keeping track of the required preprocessing, it is necessary -to initialize preprocessing instances with a :class:`DataPositions` -variable that will store the usage. - -.. code-block:: cpp - - // initialize binary computation - T::bit_type::mac_key_type::init_field(); - typename T::bit_type::mac_key_type binary_mac_key; - T::bit_type::part_type::read_or_generate_mac_key("", P, binary_mac_key); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, binary_mac_key, usage); - -While this example only uses arithmetic computation, you need to -initialize binary computation as well unless you use the compile-time -option ``NO_MIXED_CIRCUITS``. - -.. code-block:: cpp - - // output protocol - typename T::MAC_Check output(mac_key); - -Some output protocols use the MAC key to check the correctness. - -.. code-block:: cpp - - // various preprocessing - typename T::LivePrep preprocessing(0, usage); - SubProcessor processor(output, preprocessing, P); - -In this example we use live preprocessing, but it is also possible to -read preprocessing data from disk by using :class:`Sub_Data_Files` -instead. You can use a live preprocessing instances to generate -preprocessing data independently, but many protocols require that a -:class:`SubProcessor` instance has been created as well. The latter -essentially glues an instance of the output and the preprocessing -protocol together, which is necessary for Beaver-based multiplication -protocols. - -.. code-block:: cpp - - // input protocol - typename T::Input input(processor, output); - -Some input protocols depend on preprocessing and an output protocol, -which is reflect in the standard constructor. Other constructors are -available depending on the protocol. - -.. code-block:: cpp - - // multiplication protocol - typename T::Protocol protocol(P); - -This instantiates a multiplication protocol. :var:`P` is required -because some protocols start by exchanging keys for pseudo-random -secret sharing. +The :class:`ProtocolSet` contains one instance for every essential +protocol step. .. code-block:: cpp @@ -235,6 +150,14 @@ The initialization of the multiplication sets the preprocessing and output instances to use in Beaver multiplication. :func:`next_dotprod` separates dot products in the data preparation phase. +.. code-block:: cpp + + protocol.check(); + +Some protocols require a check of all multiplications up to a certain +point. To guarantee that outputs do not reveal secret information, it +has to be run before using the output protocol. + .. code-block:: cpp output.init_open(P); @@ -245,8 +168,8 @@ separates dot products in the data preparation phase. cout << "result: " << result << endl; output.Check(P); -The output protocol follows the same blueprint except that it is -necessary to call the checking in order to verify the outputs. +The output protocol follows the same blueprint as the multiplication +protocol. .. code-block:: cpp @@ -281,6 +204,9 @@ Domain Types the time of writing, 4, 8, 28, 40, 63, and 128 are supported if the storage type is large enough. + +.. _share-type-reference: + Share Types ------------ @@ -385,6 +311,28 @@ Share Types ``MaliciousShamirShare`` or ``MaliciousRep3Share``. +Protocol Setup +-------------- + +.. doxygenclass:: ProtocolSetup + :members: + +.. doxygenclass:: ProtocolSet + :members: + +.. doxygenclass:: BinaryProtocolSetup + :members: + +.. doxygenclass:: BinaryProtocolSet + :members: + +.. doxygenclass:: MixedProtocolSetup + :members: + +.. doxygenclass:: MixedProtocolSet + :members: + + Protocol Interfaces ------------------- diff --git a/doc/networking.rst b/doc/networking.rst index 16908681a..a1c61b98d 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -18,7 +18,7 @@ individually setting ports: coordination server being run as a thread of party 0. The hostname of the coordination server has to be given with the command-line parameter ``--hostname``, and the coordination server runs on the - base port number minus one, thus defaulting to 4999. Furthermore, you + base port number, thus defaulting to 5000. Furthermore, you can specify a party's listening port using ``--my-port``. 2. The parties read the information from a local file, which needs to @@ -40,7 +40,9 @@ change this by either using ``--encrypted/-e`` or If using encryption, the certificates (``Player-Data/*.pem``) must be the same on all hosts, and you have to run ``c_rehash Player-Data`` on -all of them. +all of them. ``Scripts/setup-ssl.sh`` can be used to generate the +necessary certificates. The common name has to be ``P`` +for computing parties and ``C`` for clients. .. _network-reference: diff --git a/doc/non-linear.rst b/doc/non-linear.rst index 5fe8df1f6..bcdbbd3ae 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -7,8 +7,8 @@ domains (modulus other than two) only comes in three flavors throughout MP-SPDZ: Unknown prime modulus - This approach goes back to `Catrina and Saxena - `_. It crucially relies on + This approach goes back to `Catrina and de Hoogh + `_. It crucially relies on the use of secret random bits in the arithmetic domain. Enough such bits allow to mask a secret value so that it is secure to reveal the masked value. This can then be split in bits as it is diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 3dadcfae6..1441e3524 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -16,7 +16,7 @@ is a thread created by control flow instructions such as The exceptions to the general rule are edaBit generation with malicious security and AND triples with malicious security and honest -majority, both when use bucket size three. Bucket size three implies +majority, both when using bucket size three. Bucket size three implies batches of over a million to achieve 40-bit statistical security, and in honest-majority binary computation the item size is 64, which makes the actual batch size 64 million triples. In multithreaded programs, @@ -27,3 +27,65 @@ jump whenever another batch is generated. Note that, while some protocols are flexible with the batch size and can thus be controlled using ``-b``, others mandate a batch size, which can be as large as a million. + + +Separate preprocessing +====================== + +It is possible to separate out the preprocessing from the +input-dependent ("online") phase. This is done by either option ``-F`` +or ``-f`` on the virtual machines. In both cases, the preprocessing +data is read from files, either all data per type from a single file +(``-F``) or one file per thread (``-f``). The latter allows to use +named pipes. + +The file name depends on the protocol and the computation domain. It +is generally ``/--/--P[-T]``. For example, the +triples for party 1 in SPDZ modulo a 128-bit prime can be found in +``Player-Data/2-p-128/Triples-p-P1``. The protocol shorthand can be +found by calling ``::type_short()``. See +:ref:`share-type-reference` for a description of the share types. + +Preprocessing files start with a header describing the protocol and +computation domain to avoid errors due to mismatches. The header is as +follows: + +- Length to follow (little-endian 8-byte number) +- Protocol descriptor +- Domain descriptor + +The protocol descriptor is defined by ``::type_string()``. For SPDZ modulo a prime it is ``SPDZ gfp``. + +The domain descriptor depends on the kind of domain: + +Modulo a prime + Serialization of the prime + + - Sign bit (0 as 1 byte) + - Length to follow (little-endian 4-byte number) + - Prime (big-endian) + +Modulo a power of two: + Exponent (little-endian 4-byte number) + +:math:`GF(2^n)` + - Storage size in bytes (little-endian 8-byte number). Default is 16. + - :math:`n` (little-endian 4-byte number) + +As an example, the following output of ``hexdump -C`` describes SPDZ +modulo the default 128-bit prime +(170141183460469231731687303715885907969):: + + 00000000 1d 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |........SPDZ gfp| + 00000010 00 10 00 00 00 80 00 00 00 00 00 00 00 00 00 00 |................| + 00000020 00 00 1b 80 01 |.....| + 00000025 + + +``Fake-Offline.x`` generates preprocessing data insecurely for a range +of protocols, and ``{mascot,cowgear,mal-shamir}-offline.x`` generate +sufficient preprocessing data for a specific high-level program with +MASCOT, CowGear, and malicious Shamir secret sharing, respectively. diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 1c096d985..6a79ea198 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -1,3 +1,5 @@ +.. _troubleshooting: + Troubleshooting --------------- @@ -57,10 +59,23 @@ second batch is necessary the cost shoots up. Other preprocessing methods allow for a variable batch size, which can be changed using ``-b``. Smaller batch sizes generally reduce the communication cost while potentially increasing the number of communication rounds. Try -adding ``-b 10`` to the virtal machine (or script) arguments for very +adding ``-b 10`` to the virtual machine (or script) arguments for very short computations. +Disparities in round figures +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The number of virtual machine rounds given by the compiler are not an +exact prediction of network rounds but the number of relevant protocol +calls (such as multiplication, input, output etc) in the program. The +actual number of network rounds is determined by the choice of +protocol, which might use several rounds per protocol +call. Furthermore, communication at the beginning and the end of a +computation such as random key distribution and MAC checks further +increase the number of network rounds. + + Handshake failures ~~~~~~~~~~~~~~~~~~ @@ -82,8 +97,8 @@ use the client facility. Connection failures ~~~~~~~~~~~~~~~~~~~ -MP-SPDZ requires at least one TCP port per party to be open to other -parties. In the default setting, it's 4999 and 5000 on party 0, and +MP-SPDZ requires one TCP port per party to be open to other +parties. In the default setting, it's 5000 on party 0, and 5001 on party 1 etc. You change change the base port (5000) using ``--portnumbase`` and individual ports for parties using ``--my-port``. The scripts in use a random base port number, which you From aaa90a20bbf33e4ebac270968200f04255b46eca Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 12 Jan 2022 18:57:40 +1100 Subject: [PATCH 018/221] Don't overwrite persistence files at beginning. --- Processor/Machine.hpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index d7d1a3ec3..909a8f3a2 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -94,13 +94,6 @@ Machine::Machine(int my_number, Names& playerNames, load_schedule(progname_str); - // remove persistence if necessary - for (auto& prog : progs) - { - if (prog.writes_persistance) - ofstream(Binary_File_IO::filename(my_number), ios::out); - } - #ifdef VERBOSE progs[0].print_offline_cost(); #endif From 962919c3cf592fb2c79a6282b2d62842d35e1dc9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 13 Jan 2022 14:09:53 +1100 Subject: [PATCH 019/221] Bug in regint optimizer. --- Compiler/allocator.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 7ce9896b1..9871d97fa 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -583,13 +583,6 @@ def run(self, instructions): self.cache[inst.args[0]] = res instructions[i] = ldint(inst.args[0], res, add_to_prog=False) - elif isinstance(inst, addint_class): - if inst.args[1] in self.cache and \ - self.cache[inst.args[1]] == 0: - instructions[i] = inst.args[0].link(inst.args[2]) - elif inst.args[2] in self.cache and \ - self.cache[inst.args[2]] == 0: - instructions[i] = inst.args[0].link(inst.args[1]) elif isinstance(inst, IndirectMemoryInstruction): if inst.args[1] in self.cache: instructions[i] = inst.get_direct(self.cache[inst.args[1]]) @@ -606,7 +599,4 @@ def run(self, instructions): if op == 0: instructions[i] = ldsi(inst.args[0], 0, add_to_prog=False) - elif op == 1: - instructions[i] = None - inst.args[0].link(inst.args[1]) instructions[:] = list(filter(lambda x: x is not None, instructions)) From f343d73b25ae6ddc62d47aaf9cd146362bfcbf47 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 13 Jan 2022 14:10:06 +1100 Subject: [PATCH 020/221] Bug in for_range_opt. --- Compiler/library.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Compiler/library.py b/Compiler/library.py index 7bab1951a..4f6c2de16 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -12,6 +12,7 @@ import random import collections import operator +import copy from functools import reduce def get_program(): @@ -1031,12 +1032,14 @@ def _(i): state = tuplify(initializer()) k = 0 block = get_block() + pre = copy.copy(loop_body.__globals__) while (not util.is_constant(n_loops) or k < n_loops) \ and (len(get_block()) < budget or k == 0) \ and block is get_block(): j = i + k state = reducer(tuplify(loop_body(j)), state) k += 1 + _link(pre, loop_body.__globals__) r = reducer(mem_state, state) write_state_to_memory(r) global n_opt_loops @@ -1395,9 +1398,12 @@ def wrapped_loop(): def _run_and_link(function, g=None): if g is None: g = function.__globals__ - import copy pre = copy.copy(g) res = function() + _link(pre, g) + return res + +def _link(pre, g): if g: from .types import _single for name, var in pre.items(): @@ -1407,7 +1413,6 @@ def _run_and_link(function, g=None): raise CompilerError('cannot reassign constants in blocks') if id(new_var) != id(var): new_var.link(var) - return res def do_while(loop_fn, g=None): """ Do-while loop. The loop is stopped if the return value is zero. From 0f9d5de6979915ce35e57ea747bdb5214d2dfd61 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 12 Jan 2022 20:11:28 +1100 Subject: [PATCH 021/221] Allow overwriting of persistence files. --- Compiler/instructions.py | 5 +++-- Compiler/types.py | 42 +++++++++++++++++++++++++----------- Processor/Binary_File_IO.h | 3 ++- Processor/Binary_File_IO.hpp | 18 ++++++++++++++-- Processor/Instruction.hpp | 4 ++-- Processor/Machine.hpp | 13 +++++++++++ Processor/Processor.h | 2 +- Processor/Processor.hpp | 6 ++++-- 8 files changed, 70 insertions(+), 23 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 1533fc523..a85fb25ad 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1727,14 +1727,15 @@ class writesharestofile(base.IOInstruction): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). - :param: number of shares (int) + :param: number of arguments to follow / number of shares plus one (int) + :param: position (regint, -1 for appending) :param: source (sint) :param: (repeat from source)... """ __slots__ = [] code = base.opcodes['WRITEFILESHARE'] - arg_format = itertools.repeat('s') + arg_format = tools.chain(['ci'], itertools.repeat('s')) def has_var_args(self): return True diff --git a/Compiler/types.py b/Compiler/types.py index 33df2e373..77de5d71e 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2329,16 +2329,20 @@ def read_from_file(cls, start, n_items): return stop, shares @staticmethod - def write_to_file(shares): + def write_to_file(shares, position=None): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). - :param: shares (list or iterable of sint) + :param shares: (list or iterable of sint) + :param position: start position (int/regint/cint), + defaults to end of file """ for share in shares: assert isinstance(share, sint) assert share.size == 1 - writesharestofile(*shares) + if position is None: + position = -1 + writesharestofile(regint.conv(position), *shares) @vectorized_classmethod def load_mem(cls, address, mem_type=None): @@ -3922,13 +3926,15 @@ def read_from_file(cls, *args, **kwargs): return stop, [cls._new(x) for x in shares] @classmethod - def write_to_file(cls, shares): + def write_to_file(cls, shares, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. - :param: shares (list or iterable of sfix) + :param shares: (list or iterable of sfix) + :param position: start position (int/regint/cint), + defaults to end of file """ - cls.int_type.write_to_file([x.v for x in shares]) + cls.int_type.write_to_file([x.v for x in shares], position) def store_in_mem(self, address): """ Store in memory by public address. """ @@ -5389,11 +5395,14 @@ def read_from_file(self, start): self.assign(shares) return stop - def write_to_file(self): + def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. + + :param position: start position (int/regint/cint), + defaults to end of file """ - self.value_type.write_to_file(list(self)) + self.value_type.write_to_file(list(self), position) def __add__(self, other): """ Vector addition. @@ -5723,13 +5732,20 @@ def input_from(self, player, budget=None, raw=False): def _(i): self[i].input_from(player, budget=budget, raw=raw) - def write_to_file(self): + def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. + + :param position: start position (int/regint/cint), + defaults to end of file """ @library.for_range(len(self)) def _(i): - self[i].write_to_file() + if position is None: + my_pos = None + else: + my_pos = position + i * self[i].total_size() + self[i].write_to_file(my_pos) def read_from_file(self, start): """ Read content from ``Persistence/Transactions-P.data``. diff --git a/Processor/Binary_File_IO.h b/Processor/Binary_File_IO.h index 4e38cd161..c19a129af 100644 --- a/Processor/Binary_File_IO.h +++ b/Processor/Binary_File_IO.h @@ -27,7 +27,8 @@ class Binary_File_IO * Throws file_error. */ template - void write_to_file(const string filename, const vector< T >& buffer); + void write_to_file(const string filename, const vector& buffer, + long start_pos); /* * Read from posn in the filename the binary values until the buffer is full. diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index 9878f4a6b..ef735279a 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -14,18 +14,32 @@ inline string Binary_File_IO::filename(int my_number) } template -void Binary_File_IO::write_to_file(const string filename, const vector< T >& buffer) + +void Binary_File_IO::write_to_file(const string filename, + const vector& buffer, long start_pos) { ofstream outf; - outf.open(filename, ios::out | ios::binary | ios::app); + outf.open(filename, ios::out | ios::binary | ios::ate | ios::in); if (outf.fail()) { throw file_error(filename); } + if (start_pos != -1) + { + long write_pos = start_pos * T::size(); + // fill with zeros if needed + for (long i = outf.tellp(); i < write_pos; i++) + outf.put(0); + outf.seekp(write_pos); + } + for (unsigned int i = 0; i < buffer.size(); i++) { buffer[i].output(outf, false); } + if (outf.fail()) + throw runtime_error("failed writing to " + filename); + outf.close(); } diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e45a85045..25fa666f4 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -273,7 +273,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_vector(2, start, s); break; // open instructions + read/write instructions with variable length args - case WRITEFILESHARE: case OPEN: case GOPEN: case MULS: @@ -376,6 +375,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case BITDECINT: case EDABIT: case SEDABIT: + case WRITEFILESHARE: num_var_args = get_int(s) - 1; r[0] = get_int(s); get_vector(num_var_args, start, s); @@ -1175,7 +1175,7 @@ inline void Instruction::execute(Processor& Proc) const break; case WRITEFILESHARE: // Write shares to file system - Proc.write_shares_to_file(start); + Proc.write_shares_to_file(Proc.read_Ci(r[0]), start); break; case READFILESHARE: // Read shares from file system diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 909a8f3a2..a43c9d475 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -94,6 +94,19 @@ Machine::Machine(int my_number, Names& playerNames, load_schedule(progname_str); + // initialize persistence if necessary + for (auto& prog : progs) + { + if (prog.writes_persistance) + { + string filename = Binary_File_IO::filename(my_number); + ifstream pers(filename); + if (pers.fail()) + ofstream pers(filename, ios::binary); + break; + } + } + #ifdef VERBOSE progs[0].print_offline_cost(); #endif diff --git a/Processor/Processor.h b/Processor/Processor.h index a78058cd1..c91b677bf 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -239,7 +239,7 @@ class Processor : public ArithmeticProcessor // Read and write secret numeric data to file (name hardcoded at present) void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); - void write_shares_to_file(const vector& data_registers); + void write_shares_to_file(long start_pos, const vector& data_registers); cint get_inverse2(unsigned m); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index caea1e678..c55a6dfc1 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -370,7 +370,9 @@ void Processor::read_shares_from_file(int start_file_posn, int end_ // Append share data in data_registers to end of file. Expects Persistence directory to exist. template -void Processor::write_shares_to_file(const vector& data_registers) { +void Processor::write_shares_to_file(long start_pos, + const vector& data_registers) +{ string filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -382,7 +384,7 @@ void Processor::write_shares_to_file(const vector& data_regist inpbuf[i] = get_Sp_ref(data_registers[i]); } - binary_file_io.write_to_file(filename, inpbuf); + binary_file_io.write_to_file(filename, inpbuf, start_pos); } template From 9fcffad831ca4c452e6d9c322ce49d47f64187ef Mon Sep 17 00:00:00 2001 From: jvmncs Date: Tue, 18 Jan 2022 08:49:00 -0500 Subject: [PATCH 022/221] approx_sigmoid is attributed to an earlier paper --- Compiler/ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 7e53a78f8..5c4664be8 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -104,7 +104,7 @@ def sigmoid_prime(x): @vectorize def approx_sigmoid(x, n=3): """ Piece-wise approximate sigmoid as in - `Dahl et al. `_ + `Hong et al. `_ :param x: input :param n: number of pieces, 3 (default) or 5 From b5ff3735999b43f58850f39d7458cd9b8def80ed Mon Sep 17 00:00:00 2001 From: shareong <740310627@qq.com> Date: Thu, 20 Jan 2022 22:42:43 +0800 Subject: [PATCH 023/221] fix close socket error --- Networking/Player.cpp | 4 +++- Networking/sockets.cpp | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 031685697..c9e657da2 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -248,7 +248,9 @@ PlainPlayer::~PlainPlayer() /* Close down the sockets */ for (auto socket : sockets) close_client_socket(socket); - close_client_socket(send_to_self_socket); + #ifndef PPC_COMMUNICATION + close_client_socket(send_to_self_socket); + #endif } } diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index 124ba18ca..a1fe34c84 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -133,6 +133,6 @@ void close_client_socket(int socket) { char tmp[1000]; sprintf(tmp, "close(%d)", socket); - // error(tmp); + error(tmp); } } From fc3a2a0f320e27910bbfdba1c96ef64e6633337f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 24 Jan 2022 13:24:51 +1100 Subject: [PATCH 024/221] Personal array functionality. --- Compiler/types.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/Compiler/types.py b/Compiler/types.py index 77de5d71e..3fdc6cf05 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1698,6 +1698,12 @@ def _san(self, other): def _div_san(self): return self._v.conv((library.get_player_id() == self.player)._v).if_else(self._v, 1) + def __setitem__(self, index, value): + self._san(value) + self._v[index] = value + + __getitem__ = lambda self, index: personal(self.player, self._v[index]) + __add__ = lambda self, other: personal(self.player, self._san(other) + other) __sub__ = lambda self, other: personal(self.player, self._san(other) - other) __mul__ = lambda self, other: personal(self.player, self._san(other) * other) @@ -5500,6 +5506,14 @@ def binary_output(self, player=None): """ self.get_vector().binary_output(player) + def reveal_to(self, player): + """ Reveal secret array to :py:obj:`player`. + + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` containing an array + """ + return personal(player, self.create_from(self[:].reveal_to(player)._v)) + def sort(self, n_threads=None): """ Sort in place using Batchers' odd-even merge mergesort From 5584e1818dd52eefbe89b3b0e37aee52d11daa2d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 31 Jan 2022 14:29:44 +1100 Subject: [PATCH 025/221] Bugs in binary register conversion. --- GC/Secret.hpp | 2 +- Yao/YaoGarbleWire.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 01c70247c..88f1926a6 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -140,7 +140,7 @@ T& GC::Secret::get_new_reg() template void Secret::load_clear(int n, const Integer& x) { - if ((unsigned)n < 8 * sizeof(x) and abs(x.get()) > (1LL << n)) + if ((unsigned)n < 8 * sizeof(x) and (unsigned long) abs(x.get()) > (1ul << n)) throw out_of_range("public value too long"); #ifdef DEBUG_ROUNDS2 cout << "secret from integer " << hex << this << dec << " " << endl; diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 37931df43..05a8646db 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -241,7 +241,7 @@ void YaoGarbleWire::convcbit2s(GC::Processor& processor, int n = min(unsigned(unit), instruction.get_n() - i * unit); dest.resize_regs(n); for (int j = 0; j < n; j++) - dest.get_reg(i).public_input( + dest.get_reg(j).public_input( processor.C[instruction.get_r(1) + i].get_bit(j)); } } From d50e97fde91369256caf339c0c55e19071d542f2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 1 Feb 2022 13:54:53 +1100 Subject: [PATCH 026/221] Simplify code. --- Compiler/GC/types.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 53da15ba2..94d520825 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -111,8 +111,7 @@ def load_mem(cls, address, mem_type=None, size=None): if mem_type == 'sd': return cls.load_dynamic_mem(address) else: - for i in range(res.size): - cls.mem_op(cls.load_inst, res[i], address + i) + cls.mem_op(cls.load_inst, res, address) return res def store_in_mem(self, address): self.mem_op(self.store_inst, self, address) From 61d40b7d8392ee836e973df7a26bc53154c3d6a7 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 4 Feb 2022 11:16:12 +1100 Subject: [PATCH 027/221] Fix bugs in mathematical functions using binary circuits. --- Compiler/GC/types.py | 10 +++++++++- Compiler/mpc_math.py | 10 +++++----- Compiler/types.py | 1 + 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 94d520825..13619c7f9 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -811,7 +811,7 @@ def n_bits(self): def store_in_mem(self, address): for i, x in enumerate(self.elements()): x.store_in_mem(address + i) - def bit_decompose(self, n_bits=None, security=None): + def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None): return self.v[:n_bits] bit_compose = from_vec def reveal(self): @@ -1267,6 +1267,9 @@ class sbitfixvec(_fix): int_type = sbitintvec.get_type(sbitfix.k) float_type = type(None) clear_type = cbitfix + @property + def bit_type(self): + return type(self.v[0]) @classmethod def set_precision(cls, f, k=None): super(sbitfixvec, cls).set_precision(f=f, k=k) @@ -1284,6 +1287,8 @@ def __init__(self, value=None, *args, **kwargs): if isinstance(value, (list, tuple)): self.v = self.int_type.from_vec(sbitvec([x.v for x in value])) else: + if isinstance(value, sbitvec): + value = self.int_type(value) super(sbitfixvec, self).__init__(value, *args, **kwargs) def elements(self): return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()] @@ -1293,9 +1298,12 @@ def mul(self, other): else: return super(sbitfixvec, self).mul(other) def __xor__(self, other): + if util.is_zero(other): + return self return self._new(self.v ^ other.v) def __and__(self, other): return self._new(self.v & other.v) + __rxor__ = __xor__ @staticmethod def multipliable(other, k, f, size): class cls(_fix): diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 322989b34..47253dc43 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -290,7 +290,7 @@ class my_fix(type(a)): # how many bits to use from integer part n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) n_bits = a.f + n_int_bits - sint = types.sint + sint = a.int_type if types.program.options.ring and not as19: intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k @@ -367,17 +367,17 @@ class my_fix(type(a)): bits = a.v.bit_decompose(a.k, maybe_mixed=True) lower = sint.bit_compose(bits[:a.f]) higher_bits = bits[a.f:n_bits] - s = sint.conv(bits[-1]) + s = a.bit_type.conv(bits[-1]) bits_to_check = bits[n_bits:-1] if not as19: - c = types.sfix._new(lower, k=a.k, f=a.f) + c = a._new(lower, k=a.k, f=a.f) assert(len(higher_bits) == n_bits - a.f) pow2_bits = [sint.conv(x) for x in higher_bits] d = floatingpoint.Pow2_from_bits(pow2_bits) g = exp_from_parts(d, c) - small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, + small_result = a._new(g.v.round(a.f + 2 ** n_int_bits, 2 ** n_int_bits, signed=False, - nearest=types.sfix.round_nearest), + nearest=a.round_nearest), k=a.k, f=a.f) if zero_output: t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), diff --git a/Compiler/types.py b/Compiler/types.py index 3fdc6cf05..0063fdc16 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -4274,6 +4274,7 @@ class sfix(_fix): :params _v: int/float/regint/cint/sint/sfloat """ int_type = sint + bit_type = sintbit clear_type = cfix @vectorized_classmethod From 33991a91c43e9d73ea0c5553df54c44a000ee919 Mon Sep 17 00:00:00 2001 From: Shareong <740310627@qq.com> Date: Tue, 15 Feb 2022 19:47:25 +0800 Subject: [PATCH 028/221] send job type to gateway --- Networking/Player.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index c9e657da2..f3064616a 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -293,6 +293,7 @@ void PlainPlayer::setup_sockets(const vector& names, auto pn = job_id + "-" + id_base + "-" + to_string(i); cerr << "Gateway pn: " << pn << endl; set_up_client_socket(sockets[i],names[i].c_str(),ports[i]); + octetStream("MPC").Send(sockets[i]); octetStream(pn).Send(sockets[i]); } #else From 0f7020d791a667ede375aa365f109ac286e89d43 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 17 Feb 2022 13:21:19 +1100 Subject: [PATCH 029/221] Semi-honest computation based on threshold semi-homomorphic encryption. --- CHANGELOG.md | 13 +- CONFIG | 1 + Compiler/GC/instructions.py | 37 ++- Compiler/GC/types.py | 28 +- Compiler/allocator.py | 20 +- Compiler/instructions.py | 177 +++++++------ Compiler/instructions_base.py | 57 ++++- Compiler/library.py | 57 ++++- Compiler/ml.py | 351 +++++++++++++++++++++++--- Compiler/oram.py | 1 + Compiler/program.py | 11 +- Compiler/types.py | 97 +++++-- FHE/FHE_Keys.cpp | 16 +- FHE/FHE_Keys.h | 2 + FHE/FHE_Params.cpp | 15 ++ FHE/FHE_Params.h | 6 +- FHE/NTL-Subs.cpp | 11 +- FHE/NTL-Subs.h | 2 +- FHE/NoiseBounds.cpp | 5 +- FHE/Ring_Element.cpp | 1 + FHE/Rq_Element.cpp | 9 +- FHE/Rq_Element.h | 8 +- FHEOffline/DataSetup.cpp | 2 +- FHEOffline/Multiplier.cpp | 7 + FHEOffline/Multiplier.h | 3 + FHEOffline/PairwiseSetup.cpp | 14 +- FHEOffline/PairwiseSetup.h | 2 +- FHEOffline/SimpleDistDecrypt.cpp | 8 + FHEOffline/SimpleDistDecrypt.h | 1 + FHEOffline/TemiSetup.cpp | 59 +++++ FHEOffline/TemiSetup.h | 34 +++ GC/Memory.h | 2 +- GC/ShareSecret.h | 1 + GC/TinySecret.h | 1 + GC/instructions.h | 2 +- Machines/ShamirMachine.hpp | 1 + Machines/temi-party.cpp | 37 +++ Makefile | 5 +- Math/FixedVec.h | 5 - Math/Zp_Data.h | 2 +- Math/gf2n.cpp | 16 +- Math/mpn_fixed.h | 6 + Networking/Player.h | 1 + OT/BaseOT.cpp | 18 +- Processor/Binary_File_IO.hpp | 13 +- Processor/Input.h | 19 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 2 + Processor/Instruction.hpp | 46 ++-- Processor/Machine.hpp | 15 +- Processor/Memory.h | 4 +- Processor/Memory.hpp | 7 +- Processor/PrivateOutput.h | 12 +- Processor/PrivateOutput.hpp | 33 ++- Processor/Processor.h | 6 +- Processor/Processor.hpp | 36 ++- Processor/Program.cpp | 2 +- Processor/Program.h | 4 +- Processor/SpecificPrivateOutput.h | 65 +++++ Programs/Source/falcon_alex.mpc | 100 ++++++++ Programs/Source/keras_cifar_lenet.mpc | 45 ++++ Programs/Source/keras_mnist_dense.mpc | 3 +- Programs/Source/keras_mnist_lenet.mpc | 13 + Programs/Source/mnist_full_A.mpc | 6 + Programs/Source/mnist_full_C.mpc | 8 +- Protocols/Atlas.hpp | 6 + Protocols/Hemi.hpp | 2 +- Protocols/HemiMatrixPrep.h | 5 +- Protocols/HemiMatrixPrep.hpp | 68 +++-- Protocols/HemiPrep.h | 3 + Protocols/HemiPrep.hpp | 14 + Protocols/HemiShare.h | 1 + Protocols/LowGearKeyGen.hpp | 8 +- Protocols/MAC_Check.h | 17 +- Protocols/MAC_Check.hpp | 1 + Protocols/MAC_Check_Base.h | 4 + Protocols/MalRepRingShare.h | 4 +- Protocols/MaliciousRep3Share.h | 3 +- Protocols/MaliciousShamirPO.h | 3 +- Protocols/MaliciousShamirShare.h | 4 +- Protocols/MamaShare.h | 6 - Protocols/PostSacriRepFieldShare.h | 4 +- Protocols/PostSacriRepRingShare.h | 4 +- Protocols/ProtocolSet.h | 25 +- Protocols/Rep3Share.h | 7 +- Protocols/Rep3Share2k.h | 3 +- Protocols/Rep4Input.h | 1 - Protocols/Rep4Input.hpp | 6 - Protocols/Replicated.h | 7 - Protocols/Replicated.hpp | 6 +- Protocols/ReplicatedPrep.hpp | 33 ++- Protocols/ReplicatedPrivateOutput.h | 26 -- Protocols/ReplicatedPrivateOutput.hpp | 30 --- Protocols/Semi.h | 6 + Protocols/SemiInput.h | 29 +-- Protocols/SemiInput.hpp | 62 ++++- Protocols/Shamir.h | 1 - Protocols/Shamir.hpp | 36 +-- Protocols/ShamirInput.h | 7 +- Protocols/ShamirInput.hpp | 33 ++- Protocols/ShamirMC.h | 4 + Protocols/ShamirMC.hpp | 13 + Protocols/ShamirShare.h | 7 +- Protocols/Share.h | 1 + Protocols/ShareInterface.h | 1 + Protocols/SpdzWiseInput.h | 3 - Protocols/SpdzWiseInput.hpp | 18 -- Protocols/SpdzWiseMC.h | 2 +- Protocols/SpdzWisePrep.hpp | 1 - Protocols/TemiPrep.h | 72 ++++++ Protocols/TemiPrep.hpp | 129 ++++++++++ Protocols/TemiShare.h | 42 +++ Protocols/fake-stuff.hpp | 9 +- README.md | 33 ++- Scripts/prep-usage.py | 23 ++ Scripts/temi.sh | 8 + Scripts/test_tutorial.sh | 2 +- Tools/Buffer.h | 4 + Tools/Exceptions.cpp | 4 +- Tools/Exceptions.h | 2 +- Tools/octetStream.h | 2 + Utils/binary-example.cpp | 4 +- Utils/mixed-example.cpp | 4 +- Utils/paper-example.cpp | 4 +- doc/instructions.rst | 10 +- doc/low-level.rst | 5 + doc/non-linear.rst | 2 +- doc/preprocessing.rst | 34 ++- doc/requirements.txt | 1 + 129 files changed, 1973 insertions(+), 539 deletions(-) create mode 100644 FHEOffline/TemiSetup.cpp create mode 100644 FHEOffline/TemiSetup.h create mode 100644 Machines/temi-party.cpp create mode 100644 Processor/SpecificPrivateOutput.h create mode 100644 Programs/Source/falcon_alex.mpc create mode 100644 Programs/Source/keras_cifar_lenet.mpc delete mode 100644 Protocols/ReplicatedPrivateOutput.h delete mode 100644 Protocols/ReplicatedPrivateOutput.hpp create mode 100644 Protocols/TemiPrep.h create mode 100644 Protocols/TemiPrep.hpp create mode 100644 Protocols/TemiShare.h create mode 100755 Scripts/prep-usage.py create mode 100755 Scripts/temi.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b75d24f8..6a0406a8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,17 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.2.9 (Jan 11, 2021) +## 0.3.0 (Feb 17, 2022) + +- Semi-honest computation based on threshold semi-homomorphic encryption +- Batch normalization backward propagation +- AlexNet for CIFAR-10 +- Specific private output protocols +- Semi-honest additive secret sharing without communication +- Sending of personal values +- Allow overwriting of persistence files +- Protocol signature in persistence files + +## 0.2.9 (Jan 11, 2022) - Disassembler - Run-time parameter for probabilistic truncation error diff --git a/CONFIG b/CONFIG index ba6855ea9..05b3683d5 100644 --- a/CONFIG +++ b/CONFIG @@ -42,6 +42,7 @@ else AVX_OT = 1 endif else +ARCH = AVX_OT = 0 endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index fc64ae2d2..ef9c14a3f 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -497,7 +497,7 @@ class movsb(NonVectorInstruction): code = opcodes['MOVSB'] arg_format = ['sbw','sb'] -class trans(base.VarArgsInstruction): +class trans(base.VarArgsInstruction, base.DynFormatInstruction): """ Secret bit register vector transpose. The first destination vector will contain the least significant bits of all source vectors etc. @@ -511,10 +511,22 @@ class trans(base.VarArgsInstruction): code = opcodes['TRANS'] is_vec = lambda self: True def __init__(self, *args): - self.arg_format = ['int'] + ['sbw'] * args[0] + \ - ['sb'] * (len(args) - 1 - args[0]) super(trans, self).__init__(*args) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + n = next(args) + for i in range(n): + yield 'sbw' + next(args) + while True: + try: + yield 'sb' + next(args) + except StopIteration: + break + class bitb(NonVectorInstruction): """ Copy fresh secret random bit to secret bit register. @@ -560,7 +572,7 @@ def add_usage(self, req_node): req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1]) class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, - base.Mergeable): + base.Mergeable, base.DynFormatInstruction): """ Copy private input to secret bit registers bit by bit. The input is read as floating-point number, multiplied by a power of two, rounded to an integer, and then decomposed into bits. @@ -577,11 +589,18 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, code = opcodes['INPUTBVEC'] def __init__(self, *args, **kwargs): - self.arg_format = [] - for x in self.get_arg_tuples(args): - self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3) super(inputbvec, self).__init__(*args, **kwargs) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + yield 'p' + for j in range(n - 3): + yield 'sbw' + yield 'int' + @staticmethod def get_arg_tuples(args): i = 0 @@ -590,10 +609,6 @@ def get_arg_tuples(args): i += args[i] assert i == len(args) - def merge(self, other): - self.args += other.args - self.arg_format += other.arg_format - def add_usage(self, req_node): for x in self.get_arg_tuples(self.args): req_node.increment(('bit', 'input', x[2]), x[0] - 3) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 13619c7f9..38c37a261 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -41,7 +41,7 @@ class bitsn(cls): return cls.types[length] @classmethod def conv(cls, other): - if isinstance(other, cls): + if isinstance(other, cls) and cls.n == other.n: return other elif isinstance(other, MemValue): return cls.conv(other.read()) @@ -246,14 +246,20 @@ def conv_regint_by_bit(cls, n, res, other): assert n == res.n assert n == other.size cls.conv_cint_vec(cint(other, size=other.size), res) + @classmethod + def conv(cls, other): + if isinstance(other, cbits) and cls.n != None and \ + cls.n // cls.unit == other.n // cls.unit: + return other + else: + return super(cbits, cls).conv(other) types = {} def load_int(self, value): - if self.n <= 64: - tmp = regint(value) - elif value == self.long_one(): - tmp = cint(1, size=self.n) - else: - raise CompilerError('loading long integers to cbits not supported') + n_limbs = math.ceil(self.n / self.unit) + tmp = regint(size=n_limbs) + for i in range(n_limbs): + tmp[i].load_int(value % 2 ** self.unit) + value >>= self.unit self.load_other(tmp) def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) @@ -1163,14 +1169,14 @@ class cbitfix(object): @classmethod def _new(cls, value): res = cls() + if cls.k < value.unit: + bits = value.bit_decompose(cls.k) + sign = bits[-1] + value += (sign << (cls.k)) * -1 res.v = value return res def output(self): v = self.v - if self.k < v.unit: - bits = self.v.bit_decompose(self.k) - sign = bits[-1] - v += (sign << (self.k)) * -1 inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), cbits(0), cbits(0)) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 9871d97fa..cf2f13ef4 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -403,6 +403,20 @@ def keep_merged_order(instr, n, t): add_edge(last_input[t][1], n) last_input[t][0] = n + def keep_text_order(inst, n): + if inst.get_players() is None: + # switch + for x in list(last_input.keys()): + if isinstance(x, int): + add_edge(last_input[x][0], n) + del last_input[x] + keep_merged_order(instr, n, None) + elif last_input[None][0] is not None: + keep_merged_order(instr, n, None) + else: + for player in inst.get_players(): + keep_merged_order(instr, n, player) + for n,instr in enumerate(block.instructions): outputs,inputs = instr.get_def(), instr.get_used() @@ -427,7 +441,7 @@ def keep_merged_order(instr, n, t): # will be merged if isinstance(instr, TextInputInstruction): - keep_merged_order(instr, n, TextInputInstruction) + keep_text_order(instr, n) elif isinstance(instr, RawInputInstruction): keep_merged_order(instr, n, RawInputInstruction) @@ -479,10 +493,6 @@ def keep_merged_order(instr, n, t): last_print_str = n elif isinstance(instr, PublicFileIOInstruction): keep_order(instr, n, instr.__class__) - elif isinstance(instr, startprivateoutput_class): - keep_order(instr, n, startprivateoutput_class, 2) - elif isinstance(instr, stopprivateoutput_class): - keep_order(instr, n, stopprivateoutput_class, 2) elif isinstance(instr, prep_class): keep_order(instr, n, instr.args[0]) elif isinstance(instr, StackInstruction): diff --git a/Compiler/instructions.py b/Compiler/instructions.py index a85fb25ad..e06797684 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -421,6 +421,10 @@ class use_matmul(base.Instruction): code = base.opcodes['USE_MATMUL'] arg_format = ['int','int','int','int'] + @classmethod + def get_usage(cls, args): + return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i} + class run_tape(base.Instruction): """ Start tape/bytecode file in another thread. @@ -1229,15 +1233,20 @@ def __init__(self, *args, **kwargs): @base.gf2n @base.vectorize class inputmask(base.Instruction): - r""" Load secret $s_i$ with the next input mask for player $p$ and - write the mask on player $p$'s private output. """ + """ Store fresh random input mask(s) in secret register (vector) and clear + register (vector) of the relevant player. + + :param: mask (sint) + :param: mask (cint, player only) + :param: player (int) + """ __slots__ = [] code = base.opcodes['INPUTMASK'] - arg_format = ['sw', 'p'] + arg_format = ['sw', 'cw', 'p'] field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[1]), \ + req_node.increment((self.field_type, 'input', self.args[2]), \ self.get_size()) @base.vectorize @@ -1293,10 +1302,8 @@ class asm_input(base.TextInputInstruction): arg_format = tools.cycle(['sw', 'p']) field_type = 'modp' - def add_usage(self, req_node): - for player in self.args[1::2]: - req_node.increment((self.field_type, 'input', player), \ - self.get_size()) + def get_players(self): + return self.args[1::2] @base.vectorize class inputfix(base.TextInputInstruction): @@ -1305,10 +1312,8 @@ class inputfix(base.TextInputInstruction): arg_format = tools.cycle(['sw', 'int', 'p']) field_type = 'modp' - def add_usage(self, req_node): - for player in self.args[2::3]: - req_node.increment((self.field_type, 'input', player), \ - self.get_size()) + def get_players(self): + return self.args[2::3] @base.vectorize class inputfloat(base.TextInputInstruction): @@ -1322,7 +1327,7 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', player), \ 4 * self.get_size()) -class inputmixed_base(base.TextInputInstruction): +class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction): __slots__ = [] field_type = 'modp' # the following has to match TYPE: (N_DEST, N_PARAM) @@ -1341,22 +1346,30 @@ def __init__(self, name, *args): type_id = self.type_ids[name] super(inputmixed_base, self).__init__(type_id, *args) - @property - def arg_format(self): - for i in self.bases(): - t = self.args[i] - yield 'int' + @classmethod + def dynamic_arg_format(self, args): + yield 'int' + for i, t in self.bases(iter(args)): for j in range(self.types[t][0]): yield 'sw' for j in range(self.types[t][1]): yield 'int' yield self.player_arg_type + yield 'int' - def bases(self): + @classmethod + def bases(self, args): i = 0 - while i < len(self.args): - yield i - i += sum(self.types[self.args[i]]) + 2 + while True: + try: + t = next(args) + except StopIteration: + return + yield i, t + n = sum(self.types[t]) + i += n + 2 + for j in range(n + 1): + next(args) @base.vectorize class inputmixed(inputmixed_base): @@ -1380,13 +1393,16 @@ class inputmixed(inputmixed_base): player_arg_type = 'p' def add_usage(self, req_node): - for i in self.bases(): - t = self.args[i] + for i, t in self.bases(iter(self.args)): player = self.args[i + sum(self.types[t]) + 1] n_dest = self.types[t][0] req_node.increment((self.field_type, 'input', player), \ n_dest * self.get_size()) + def get_players(self): + for i, t in self.bases(iter(self.args)): + yield self.args[i + sum(self.types[t]) + 1] + @base.vectorize class inputmixedreg(inputmixed_base): """ Store private input in secret registers (vectors). The input is @@ -1412,6 +1428,9 @@ def add_usage(self, req_node): # player 0 as proxy req_node.increment((self.field_type, 'input', 0), float('inf')) + def get_players(self): + pass + @base.gf2n @base.vectorize class rawinput(base.RawInputInstruction, base.Mergeable): @@ -1433,7 +1452,23 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', player), \ self.get_size()) -class inputpersonal(base.Instruction, base.Mergeable): +class personal_base(base.Instruction, base.Mergeable): + __slots__ = [] + field_type = 'modp' + + def __init__(self, *args): + super(personal_base, self).__init__(*args) + for i in range(0, len(args), 4): + assert args[i + 2].size == args[i] + assert args[i + 3].size == args[i] + + def add_usage(self, req_node): + for i in range(0, len(self.args), 4): + player = self.args[i + 1] + req_node.increment((self.field_type, 'input', player), \ + self.args[i]) + +class inputpersonal(personal_base): """ Private input from cint. :param: vector size (int) @@ -1445,19 +1480,39 @@ class inputpersonal(base.Instruction, base.Mergeable): __slots__ = [] code = base.opcodes['INPUTPERSONAL'] arg_format = tools.cycle(['int','p','sw','c']) - field_type = 'modp' + +class privateoutput(personal_base): + """ Private input from cint. + + :param: vector size (int) + :param: player (int) + :param: destination (cint) + :param: source (sint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['PRIVATEOUTPUT'] + arg_format = tools.cycle(['int','p','cw','s']) + +class sendpersonal(base.Instruction, base.Mergeable): + """ Private input from cint. + + :param: vector size (int) + :param: destination player (int) + :param: destination (cint) + :param: source player (int) + :param: source (cint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['SENDPERSONAL'] + arg_format = tools.cycle(['int','p','cw','p','c']) def __init__(self, *args): - super(inputpersonal, self).__init__(*args) - for i in range(0, len(args), 4): + super(sendpersonal, self).__init__(*args) + for i in range(0, len(args), 5): assert args[i + 2].size == args[i] - assert args[i + 3].size == args[i] - - def add_usage(self, req_node): - for i in range(0, len(self.args), 4): - player = self.args[i + 1] - req_node.increment((self.field_type, 'input', player), \ - self.args[i]) + assert args[i + 4].size == args[i] @base.gf2n @base.vectorize @@ -1789,27 +1844,6 @@ class floatoutput(base.PublicFileIOInstruction): code = base.opcodes['FLOATOUTPUT'] arg_format = ['p','c','c','c','c'] -@base.gf2n -@base.vectorize -class startprivateoutput(base.Instruction): - r""" Initiate private output to $n$ of $s_j$ via $s_i$. """ - __slots__ = [] - code = base.opcodes['STARTPRIVATEOUTPUT'] - arg_format = ['sw','s','p'] - field_type = 'modp' - - def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[2]), \ - self.get_size()) - -@base.gf2n -@base.vectorize -class stopprivateoutput(base.Instruction): - r""" Previously iniated private output to $n$ via $c_i$. """ - __slots__ = [] - code = base.opcodes['STOPPRIVATEOUTPUT'] - arg_format = ['cw','c','p'] - @base.vectorize class rand(base.Instruction): """ Store insecure random value of specified length in clear integer @@ -2210,7 +2244,8 @@ def get_used(self): @base.gf2n @base.vectorize -class dotprods(base.VarArgsInstruction, base.DataInstruction): +class dotprods(base.VarArgsInstruction, base.DataInstruction, + base.DynFormatInstruction): """ Dot product of secret registers (vectors). Note that the vectorized version works element-wise. @@ -2238,31 +2273,29 @@ def __init__(self, *args): flat_args += [x, y] base.Instruction.__init__(self, *flat_args) - @property - def arg_format(self): + @classmethod + def dynamic_arg_format(self, args): field = 'g' if self.is_gf2n() else '' - for i in self.bases(): - yield 'int' + yield 'int' + for i, n in self.bases(args): yield 's' + field + 'w' - for j in range(self.args[i] - 2): + for j in range(n - 2): yield 's' + field + yield 'int' - gf2n_arg_format = arg_format - - def bases(self): - i = 0 - while i < len(self.args): - yield i - i += self.args[i] + @property + def gf2n_arg_format(self): + return self.arg_format() def get_repeat(self): - return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() + return sum(self.args[i] // 2 + for i, n in self.bases(iter(self.args))) * self.get_size() def get_def(self): - return [self.args[i + 1] for i in self.bases()] + return [self.args[i + 1] for i, n in self.bases(iter(self.args))] def get_used(self): - for i in self.bases(): + for i, n in self.bases(iter(self.args)): for reg in self.args[i + 2:i + self.args[i]]: yield reg diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index fb2a67b89..d6c647add 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -105,6 +105,7 @@ MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -128,6 +129,7 @@ INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -364,6 +366,7 @@ class GF2N_Instruction(instruction_cls): arg_format = copy.deepcopy(instruction_cls.arg_format) reformat(arg_format) + @classmethod def is_gf2n(self): return True @@ -505,8 +508,12 @@ def expand_merged(self, skip): for arg in self.args: try: new_regs.append(type(arg)(size=size)) - except: + except TypeError: break + except: + print([call[0][0].size for call in self.calls]) + raise + assert len(new_regs) > 1 base = 0 for call in self.calls: for new_reg, reg in zip(new_regs[1:], call[0][1:]): @@ -854,6 +861,7 @@ def has_var_args(self): def is_vec(self): return False + @classmethod def is_gf2n(self): return False @@ -902,6 +910,10 @@ def get_new_args(self, size, subs): new_args.append(arg) return new_args + @staticmethod + def get_usage(args): + return {} + # String version of instruction attempting to replicate encoded version def __str__(self): @@ -949,9 +961,18 @@ def __init__(self, f): if name == 'cisc': arg_format = itertools.chain(['str'], itertools.repeat('int')) else: - arg_format = itertools.repeat('int') - self.args = [ArgFormats[next(arg_format)](f) - for i in range(n_args)] + def arg_iter(): + i = 0 + while True: + try: + yield self.args[i].i + except AttributeError: + yield None + i += 1 + arg_format = t.dynamic_arg_format(arg_iter()) + self.args = [] + for i in range(n_args): + self.args.append(ArgFormats[next(arg_format)](f)) def __str__(self): name = self.type.__name__ @@ -963,6 +984,9 @@ def __str__(self): res += ', '.join(str(arg) for arg in self.args) return res + def get_usage(self): + return self.type.get_usage(self.args) + class VarArgsInstruction(Instruction): def has_var_args(self): return True @@ -974,6 +998,26 @@ class VectorInstruction(Instruction): def get_code(self): return super(VectorInstruction, self).get_code(len(self.args[0])) +class DynFormatInstruction(Instruction): + __slots__ = [] + + @property + def arg_format(self): + return self.dynamic_arg_format(iter(self.args)) + + @classmethod + def bases(self, args): + i = 0 + while True: + try: + n = next(args) + except StopIteration: + return + yield i, n + i += n + for j in range(n - 1): + next(args) + ### ### Basic arithmetic ### @@ -1072,6 +1116,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction): """ Input from text file or stdin """ __slots__ = [] + def add_usage(self, req_node): + for player in self.get_players(): + req_node.increment((self.field_type, 'input', player), \ + self.get_size()) + ### ### Data access instructions ### diff --git a/Compiler/library.py b/Compiler/library.py index 4f6c2de16..3f31499b0 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -223,7 +223,7 @@ def crash(condition=None): if isinstance(condition, localint): # allow crash on local values condition = condition._v - if condition == None: + if condition is None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -284,8 +284,8 @@ def get_arg(): def make_array(l): if isinstance(l, program.Tape.Register): - res = Array(1, type(l)) - res[0] = l + res = Array(len(l), type(l)) + res[:] = l else: l = list(l) res = Array(len(l), type(l[0]) if l else cint) @@ -1032,6 +1032,7 @@ def _(i): state = tuplify(initializer()) k = 0 block = get_block() + assert not isinstance(n_loops, int) or n_loops > 0 pre = copy.copy(loop_body.__globals__) while (not util.is_constant(n_loops) or k < n_loops) \ and (len(get_block()) < budget or k == 0) \ @@ -1211,7 +1212,13 @@ def decorator(loop_body): if t != regint: raise CompilerError('Not implemented for other than regint') args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') - state = tuple(initializer()) + state = initializer() + if len(state) == 0: + state_type = cint + elif isinstance(state, (tuple, list)): + state_type = type(state[0]) + else: + state_type = type(state) def f(inc): base = args[get_arg()][0] if not util.is_constant(thread_rounds): @@ -1224,8 +1231,7 @@ def f(inc): if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint, \ args[get_arg()].address + 2) - mem_state = Array(len(state), type(state[0]) \ - if state else cint, args[get_arg()][1]) + mem_state = Array(len(state), state_type, args[get_arg()][1]) @map_reduce_single(n_parallel, thread_rounds + inc, \ initializer, reducer, mem_state) def f(i): @@ -1257,14 +1263,14 @@ def f(i): threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) - if state: + if len(state): if thread_rounds: for i in range(n_threads - remainder): - state = reducer(Array(len(state), type(state[0]), \ + state = reducer(Array(len(state), state_type, \ args[remainder + i][1]), state) if remainder: for i in range(remainder): - state = reducer(Array(len(state), type(state[0]).reg_type, \ + state = reducer(Array(len(state), state_type, \ args[i][1]), state) def returner(): return untuplify(state) @@ -1300,6 +1306,39 @@ def summer(i): """ return map_sum(n_threads, None, n_loops, len(types), types) +def map_sum_simple(n_threads, n_loops, type, size): + """ Vectorized multi-threaded sum reduction. The following computes a + 100 sums of ten squares in three threads:: + + @map_sum_simple(3, 10, sint, 100) + def summer(i): + return sint(regint.inc(100, i, 0)) ** 2 + + result = summer() + + :param n_threads: number of threads (int) + :param n_loops: number of loop runs (regint/cint/int) + :param type: return type, must match the return statement + in the loop + :param size: vector size, must match the return statement + in the loop + + """ + initializer = lambda: type(0, size=size) + def summer(*args): + assert len(args) == 2 + args = list(args) + for i in (0, 1): + if isinstance(args[i], tuple): + assert len(args[i]) == 1 + args[i] = args[i][0] + for i in (0, 1): + assert len(args[i]) == size + if isinstance(args[i], Array): + args[i] = args[i][:] + return args[0] + args[1] + return map_reduce(n_threads, 1, n_loops, initializer, summer) + def tree_reduce_multithread(n_threads, function, vector): inputs = vector.Array(len(vector)) inputs.assign_vector(vector) diff --git a/Compiler/ml.py b/Compiler/ml.py index 5c4664be8..c521934fe 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -223,6 +223,7 @@ class Layer: thetas = lambda self: () debug_output = False back_batch_size = 128 + print_random_update = False @property def shape(self): @@ -254,6 +255,9 @@ def forward(self, batch=None, training=None): def __str__(self): return type(self).__name__ + str(self._Y.sizes) + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.Y.sizes) + class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None output_weights = lambda *args: None @@ -459,6 +463,10 @@ def __init__(self, N, d_out, approx=False, debug=False): self.debug = debug self.true_X = sfix.Array(N) + def __repr__(self): + return '%s(%s, %s, approx=%s)' % \ + (type(self).__name__, self.N, self.d_out, self.approx) + def _forward(self, batch): N = len(batch) d_out = self.X.sizes[1] @@ -609,10 +617,11 @@ def backward_params(self, f_schur_Y, batch): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) + A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address) + B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + @multithread(self.n_threads, self.d_in) def _(base, size): - A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) mp = B.direct_trans_mul(A, reduce=False, indices=(regint.inc(size, base), batch.get_vector(), @@ -622,16 +631,24 @@ def _(base, size): progress('nabla W (matmul)') - if self.d_in * self.d_out < 200000: - print('reduce at once') - @multithread(self.n_threads, self.d_in * self.d_out) - def _(base, size): - self.nabla_W.assign_vector( - tmp.get_vector(base, size).reduce_after_mul(), base=base) - else: - @for_range_opt(self.d_in) - def _(i): - self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() + @multithread(self.n_threads, self.d_in * self.d_out, + max_size=get_program().budget) + def _(base, size): + self.nabla_W.assign_vector( + tmp.get_vector(base, size).reduce_after_mul(), base=base) + + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % self.d_in + j = regint.get_random(64) % self.d_out + print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s', + str(self.nabla_W), i, j, tmp[i][j].v.reveal(), + self.nabla_W[i][j].reveal(), + A.get_column(j).reveal(), + B.get_column_by_row_indices( + batch.get_vector(), i).reveal()) + print_ln('batch=%s B=%s', batch, + [self.X[bi][0][i].reveal() for bi in batch]) progress('nabla W') @@ -699,6 +716,7 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): self.d_in = d_in self.d_out = d_out self.d = d + self.activation = activation self.X = MultiArray([N, d, d_in], sfix) self.Y = MultiArray([N, d, d_out], sfix) @@ -721,12 +739,17 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): else: self.f_input = self.Y + def __repr__(self): + return '%s(%s, %s, %s, activation=%s)' % \ + (type(self).__name__, self.N, self.d_in, + self.d_out, repr(self.activation)) + def reset(self): d_in = self.d_in d_out = self.d_out r = math.sqrt(6.0 / (d_in + d_out)) print('Initializing dense weights in [%f,%f]' % (-r, r)) - self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) + self.W.randomize(-r, r) self.b.assign_all(0) def input_from(self, player, raw=False): @@ -820,6 +843,12 @@ def _(base, size): regint.inc(self.d_in))), base) + if self.print_random_update: + print_ln('backward %s', self) + index = regint.get_random(64) % self.nabla_X.total_size() + print_ln('%s nabla_X at %s: %s', str(self.nabla_X), + index, self.nabla_X.to_array()[index].reveal()) + progress('nabla X') self.backward_params(f_schur_Y, batch=batch) @@ -890,6 +919,10 @@ def __init__(self, N, d1, d2=1, alpha=0.5): self.alpha = alpha self.B = MultiArray([N, d1, d2], sint) + def __repr__(self): + return '%s(%s, %s, alpha=%s)' % \ + (type(self).__name__, self.N, self.d1, self.alpha) + def forward(self, batch, training=False): if training: n_bits = -math.log(self.alpha, 2) @@ -1022,6 +1055,7 @@ class MaxPool(NoVariableLayer): def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), padding='VALID'): assert len(shape) == 4 + assert min(shape) > 0, shape for x in strides, ksize: for i in 0, 3: assert x[i] == 1 @@ -1033,12 +1067,18 @@ def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), self.Y = Tensor(output_shape, sfix) self.strides = strides self.ksize = ksize + self.padding = padding self.nabla_X = Tensor(shape, sfix) self.nabla_Y = Tensor(output_shape, sfix) self.N = shape[0] self.comparisons = MultiArray([self.N, self.X.sizes[3], ksize[1] * ksize[2]], sint) + def __repr__(self): + return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \ + (type(self).__name__, self.X.sizes, self.strides, + self.ksize, self.padding) + def _forward(self, batch): def process(pool, bi, k, i, j): def m(a, b): @@ -1165,7 +1205,7 @@ def _(base, size): self.Y[batch[0]].assign_vector(tmp, base) class FusedBatchNorm(Layer): - """ Fixed-point fused batch normalization layer. + """ Fixed-point fused batch normalization layer (inference only). :param shape: input/output shape (tuple/list of four int) """ @@ -1192,6 +1232,153 @@ def _(i, j): self.X[batch[0]][i][j].get_vector() * self.weights.get_vector() + self.bias.get_vector()) +class BatchNorm(Layer): + """ Fixed-point batch normalization layer. + + :param shape: input/output shape (tuple/list of four int) + :param approx: use approximate square root + + """ + thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) + + def __init__(self, shape, approx=True, args=None): + assert len(shape) in (2, 3, 4) + if len(shape) == 4: + shape = [shape[0], shape[1] * shape[2], shape[3]] + elif len(shape) == 2: + shape = [shape[0], 1, shape[1]] + tensors = (Tensor(shape, sfix) for i in range(4)) + self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.var, self.mu, self.weights, self.bias = arrays + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays + self.epsilon = 2 ** (-sfix.f + 1) + self.momentum = 0.1 + if args != None: + approx = 'precisebn' not in args + self.approx = approx + if approx: + print('Approximate square root inverse in batch normalization') + self.InvertSqrt = mpc_math.InvertSqrt + else: + print('Precise square root inverse in batch normalization') + self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x) + + def __repr__(self): + return '%s(%s, approx=%s)' % \ + (type(self).__name__, self.X.sizes, self.approx) + + def reset(self): + self.bias.assign_all(0) + self.weights.assign_all(1) + self.mu_hat.assign_all(0) + self.var_hat.assign_all(0) + + def _output(self, batch, mu, var): + factor = sfix.Array(len(mu)) + factor[:] = self.InvertSqrt(var[:] + self.epsilon) + @for_range_opt_multithread(self.n_threads, + [len(batch), self.X.sizes[1]]) + def _(i, j): + tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:] + self.Y[i][j][:] = self.bias[:] + tmp + + def forward(self, batch, training=False): + if training: + d = self.X.sizes[1] + d_in = self.X.sizes[2] + s = sfix.Array(d_in) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.X[batch[i]][j].get_vector()) + s.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.mu.assign_vector( + s.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + item = self.X[batch[i]][j].get_vector() + return ((item - self.mu[:]) ** 2) + self.var.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.var.assign_vector( + self.var.get_vector(base, size) / (len(batch) * d - 1), + base) + for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var): + x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:] + self._output(batch, self.mu, self.var) + if self.print_random_update: + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.mu, self.var: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', + str(self.Y), i, j, k, self.X[i][j][k].reveal(), + self.Y[i][j][k].reveal()) + else: + self._output(batch, self.mu_hat, self.var_hat) + + def backward(self, batch, compute_nabla_X=True): + factor = Array.create_from( + self.InvertSqrt(self.var[:] + self.epsilon)) + mynYf = self.X.same_shape() + gamnY = self.X.same_shape() + gamnYd = self.X.same_shape() + nYdf = self.X.same_shape() + d = self.X.sizes[1] + d_in = self.X.sizes[2] + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + tmp = self.weights[:] * self.nabla_Y[i][j][:] + gamnY[i][j] = tmp + gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:]) + mynYf[i][j] = tmp * factor[:] + nYdf[i][j] = self.nabla_Y[i][j][:] * \ + (self.X[i][j][:] - self.mu[:]) * factor[:] + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.nabla_Y[i][j][:]) + self.nabla_bias.assign(_()) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (nYdf[i][j]) + self.nabla_weights.assign(_()) + factor3 = Array.create_from(factor[:] ** 3) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (mynYf[i][j]) + s1 = Array.create_from(_()) + @multithread(self.n_threads, len(s1)) + def _(base, size): + s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (gamnYd[i][j][:] * factor3[:]) + s2 = Array.create_from(_()) + @multithread(self.n_threads, len(s2)) + def _(base, size): + s2.assign_vector( + s2.get_vector(base, size) / (len(batch) * d - 1), base) + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + self.nabla_X[i][j][:] = mynYf[i][j][:] \ + - s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:] + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.nabla_bias, self.nabla_weights: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k, + self.nabla_Y[i][j][k].reveal(), + self.nabla_X[i][j][k].reveal()) + class QuantBase(object): bias_before_reduction = True @@ -1298,6 +1485,8 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, self.padding.append(pad_total // 2) elif padding == 'VALID': self.padding = [0, 0] + elif isinstance(padding, int): + self.padding = [padding, padding] else: self.padding = padding @@ -1323,6 +1512,12 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, assert(len(output_shape) == 4) assert(len(weight_shape) == 4) + def __repr__(self): + return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \ + (type(self).__name__, self.X.sizes, self.weight_shape, + self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), + self.tf_weight_format) + def input_from(self, player, raw=False): self.input_params_from(player) self.weights.input_from(player, budget=100000, raw=raw) @@ -1545,20 +1740,20 @@ def _(i, j): self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i) if compute_nabla_X: - assert tuple(self.padding) == (0, 0) assert tuple(self.stride) == (1, 1) reverse_weights = MultiArray( [n_channels_in, weights_h, weights_w, n_channels_out], sfix) - @for_range(n_channels_out) - def _(i): + @for_range_opt_multithread(self.n_threads, n_channels_in) + def _(l): @for_range(weights_h) def _(j): @for_range(weights_w) def _(k): - @for_range(n_channels_in) - def _(l): - reverse_weights[l][weights_h-j-1][k][i] = \ - self.weights[i][j][weights_w-k-1][l] + addresses = regint.inc(n_channels_out, + self.weights[0][j][weights_w-k-1].get_address(l), + reduce(operator.mul, self.weights.sizes[1:])) + reverse_weights[l][weights_h-j-1][k].assign_vector( + self.weights.value_type.load_mem(addresses)) padded_w = inputs_w + 2 * padding_w padded_h = inputs_h + 2 * padding_h if padding_h or padding_w: @@ -1579,14 +1774,16 @@ def _(i, j): unreduced_sfix._new(res).reduce_after_mul(), i, None, None, j) if padding_h or padding_w: - @for_range(N) + @for_range_opt_multithread(self.n_threads, N) def _(i): @for_range(inputs_h) def _(j): @for_range(inputs_w) def _(k): + jj = j + padding_w + kk = k + padding_w self.nabla_X[i][j][k].assign_vector( - output[i][j][k].get_vector()) + output[i][jj][kk].get_vector()) if self.debug_output: @for_range(len(batch)) @@ -1806,6 +2003,7 @@ def __init__(self, report_loss=None): self.report_loss = report_loss self.X_by_label = None self.print_update_average = False + self.print_random_update = False self.print_losses = False self.print_loss_reduction = False self.i_epoch = MemValue(0) @@ -1846,6 +2044,7 @@ def reset(self): def batch_for(self, layer, batch): if layer in (self.layers[0], self.layers[-1]): + assert not isinstance(layer, BatchNorm) return batch else: batch = regint.Array(len(batch)) @@ -1876,6 +2075,21 @@ def forward(self, N=None, batch=None, keep_intermediate=True, if i != len(self.layers) - 1 or run_last: layer.forward(batch=self.batch_for(layer, batch), training=training) + if self.print_random_update: + print_ln('forward layer %s', layer) + l = min(100, layer.Y[i].total_size()) + i = regint.get_random(64) % len(batch) + if l < 100: + j = 0 + else: + j = regint.get_random(64) % \ + (layer.Y[i].total_size() - l) + print_ln('forward layer %s at (%s, %s): %s', layer, i, j, + layer.Y[i].to_array().get_vector(j, l).reveal()) + i = regint.get_random(64) % layer.Y[0].total_size() + print_ln('forward layer %s vertical at %s: %s', layer, i, + [layer.Y[j].to_array()[i].reveal() + for j in range(len(batch))]) if self.time_layers: stop_timer(100 + i) break_point() @@ -1979,7 +2193,11 @@ def _(j): label * n) self.forward(batch=batch, training=True) self.backward(batch=batch) + if self.time_layers: + start_timer(1000) self.update(i, batch=batch) + if self.time_layers: + stop_timer(1000) loss_sum.iadd(self.layers[-1].l) if self.print_loss_reduction: before = self.layers[-1].average_loss(N) @@ -2070,6 +2288,8 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, if 'nomom' in program.args: self.momentum = 0 self.print_losses = 'print_losses' in program.args + self.print_random_update = 'print_random_update' in program.args + Layer.print_random_update = self.print_random_update self.time_layers = 'time_layers' in program.args self.revealing_correctness = not 'no_acc' in program.args self.layers[-1].compute_loss = not 'no_loss' in program.args @@ -2099,6 +2319,16 @@ def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, print_ln('loss %s', self.layers[-1].l.reveal()) self.output_weights() return + if 'bench10' in program.args or 'bench1' in program.args: + n = 1 if 'bench1' in program.args else 10 + print('benchmarking %s iterations' % n) + @for_range(n) + def _(i): + batch = Array.create_from(regint.inc(batch_size)) + self.forward(batch=batch, training=True) + self.backward(batch=batch) + self.update(0, batch=batch) + return @for_range(n_runs) def _(i): if not acc_first: @@ -2115,6 +2345,7 @@ def _(i): cfix(self.n_correct, k=63, f=31) / n_trained, self.n_correct, n_trained) if test_X and test_Y: + print('use test set') n_test = len(test_Y) n_correct, loss = self.reveal_correctness(test_X, test_Y, acc_batch_size) @@ -2211,7 +2442,8 @@ def _(base, size): util.max, abs_g.get_vector()) scale = MemValue(sfix._new(library.AppRcr( max_g.v, max_g.k, max_g.f, simplex_flag=True))) - @multithread(self.n_threads, m.total_size()) + @multithread(self.n_threads, m.total_size(), + max_size=get_program().budget) def _(base, size): m_part = m.get_vector(base, size) v_part = v.get_vector(base, size) @@ -2333,20 +2565,33 @@ def _(i): print_ln_if((x > limit) + (x < -limit), 'theta epoch=%s %s index=%s %s', i_epoch.read(), str(theta), i, x) - index = regint.get_random(64) % len(a) - print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, - aa[1][index], aa[0][index], aa[2][index]) + if self.print_random_update: + print_ln('update') + l = min(100, nabla.total_size()) + if l < 100: + index = 0 + else: + index = regint.get_random(64) % (nabla.total_size() - l) + print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), + index, nabla.to_array().get_vector(index, l).reveal(), + delta_theta.to_array().get_vector(index, l).reveal(), + theta.to_array().get_vector(index, l).reveal()) self.gamma.imul(1 - 10 ** - 6) def apply_padding(input_shape, kernel_size, strides, padding): + if isinstance(padding, int): + input_shape = [x + 2 * padding for x in input_shape] + padding = 'valid' if padding == 'valid': - return (input_shape[0] - kernel_size[0] + 1) // strides[0], \ + res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \ (input_shape[1] - kernel_size[1] + 1) // strides[1], + assert min(res) > 0, (input_shape, kernel_size, strides, padding) + return res elif padding == 'same': - return (input_shape[1]) // strides[0], \ - (input_shape[2]) // strides[1], + return (input_shape[0]) // strides[0], \ + (input_shape[1]) // strides[1], else: - raise Exception('invalid padding: ' + padding) + raise Exception('invalid padding: %s' % padding) class keras: class layers: @@ -2354,7 +2599,7 @@ class layers: Dense = lambda *args, **kwargs: ('dense', args, kwargs) def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', - activation=None): + activation=None, input_shape=None): return 'conv2d', {'filters': filters, 'kernel_size': kernel_size, 'strides': strides, 'padding': padding, 'activation': activation} @@ -2369,6 +2614,13 @@ def Dropout(rate): raise Exception('rate needs to be a power of two') return 'dropout', rate + def Activation(activation): + assert(activation == 'relu') + return activation, + + def BatchNormalization(): + return 'batchnorm', + class optimizers: SGD = lambda *args, **kwargs: ('sgd', args, kwargs) Adam = lambda *args, **kwargs: ('adam', args, kwargs) @@ -2383,12 +2635,25 @@ def __init__(self, layers): def compile(self, optimizer): self.optimizer = optimizer + def compile_by_args(self, program): + if 'adam' in program.args: + self.optimizer = 'adam', [], {} + elif 'amsgrad' in program.args: + self.optimizer = 'adam', [], {'amsgrad': True} + else: + self.optimizer = 'sgd', [], {} + @property def trainable_variables(self): if self.opt == None: raise Exception('need to run build() or fit() first') return list(self.opt.thetas) + def summary(self): + sizes = [var.total_size() for var in self.trainable_variables] + print(sizes) + print('Trainable params:', sum(sizes)) + def build(self, input_shape, batch_size=128): data_input_shape = input_shape if self.opt != None and \ @@ -2415,12 +2680,11 @@ def build(self, input_shape, batch_size=128): if i == len(self.layers) - 1: if layer[2].get('activation', 'softmax') in \ ('softmax', 'sigmoid'): - del layer[2]['activation'] + layer[2].pop('activation', None) layers.append(Dense(N, n_units, layer[1][0], **layer[2])) + input_shape = layers[-1].Y.sizes elif name == 'conv2d': - if len(layers) != 0: - input_shape = layers[-1].Y.sizes input_shape = list(input_shape) + \ [1] * (4 - len(input_shape)) print (layer[1]) @@ -2437,9 +2701,13 @@ def build(self, input_shape, batch_size=128): output_shape = [batch_size] + list( apply_padding(input_shape[1:3], kernel_size, strides, padding)) + [filters] + padding = padding.upper() if isinstance(padding, str) \ + else padding layers.append(FixConv2d(input_shape, weight_shape, (filters,), output_shape, - strides, padding.upper())) + strides, padding)) + input_shape = output_shape + print('conv output shape', output_shape) elif name == 'maxpool': pool_size = layer[1]['pool_size'] strides = layer[1]['strides'] @@ -2450,16 +2718,23 @@ def build(self, input_shape, batch_size=128): strides = (strides, strides) if strides == None: strides = pool_size - layers.append(MaxPool(layers[-1].Y.sizes, + layers.append(MaxPool(input_shape, [1] + list(strides) + [1], [1] + list(pool_size) + [1], - padding.upper())) + padding)) + input_shape = layers[-1].Y.sizes elif name == 'dropout': layers.append(Dropout(batch_size, reduce( operator.mul, layers[-1].Y.sizes[1:]), alpha=layer[1])) + input_shape = layers[-1].Y.sizes elif name == 'flatten': pass + elif name == 'relu': + layers.append(Relu(layers[-1].Y.sizes)) + elif name == 'batchnorm': + input_shape = layers[-1].Y.sizes + layers.append(BatchNorm(layers[-1].Y.sizes)) else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: diff --git a/Compiler/oram.py b/Compiler/oram.py index 443d826cd..543fc4aab 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1493,6 +1493,7 @@ def f(i): self.l[i] = [0] * self.elements_per_block time() print_ln('packed ORAM init %s/%s', i, real_init_rounds) + print_ln('packed ORAM init done') print('index initialized, size', size) def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple diff --git a/Compiler/program.py b/Compiler/program.py index 5dad8e516..366723304 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -580,10 +580,19 @@ def disable_memory_warnings(self): @staticmethod def read_tapes(schedule): + m = re.search(r'([^/]*)\.mpc', schedule) + if m: + schedule = m.group(1) if not os.path.exists(schedule): schedule = 'Programs/Schedules/%s.sch' % schedule - lines = open(schedule).readlines() + try: + lines = open(schedule).readlines() + except FileNotFoundError: + print('%s not found, have you compiled the program?' % schedule, + file=sys.stderr) + sys.exit(1) + for tapename in lines[2].split(' '): yield tapename.strip() diff --git a/Compiler/types.py b/Compiler/types.py index 0063fdc16..1dbe1f909 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1675,6 +1675,13 @@ def output(self): __ne__ = lambda self, other: localint(self._v != other) class personal(Tape._no_truth): + """ Value known to one player. Supports operations with public + values and personal values known to the same player. Can be used + with :py:func:`~Compiler.library.print_ln_to`. + + :param player: player (int) + :param value: cleartext value (cint, cfix, cfloat) or array thereof + """ def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -1685,8 +1692,24 @@ def __init__(self, player, value): self._v = value def binary_output(self): + """ Write binary output to + ``Player-Data/Binary-Output-P-`` if + supported by underlying type. Player must be known at compile time.""" self._v.binary_output(self.player) + def reveal_to(self, player): + """ Pass personal value to another player. """ + if isinstance(self._v, Array): + source = self._v[:] + else: + source = self._v + source = cint.conv(source) + res = cint(size=source.size) + sendpersonal(source.size, player, res, self.player, source) + if isinstance(self._v, Array): + res = Array.create_from(res) + return personal(player, res) + def bit_decompose(self, length): return [personal(self.player, x) for x in self._v.bit_decompose(length)] @@ -1858,8 +1881,13 @@ def get_random_inverse(cls): @vectorized_classmethod @set_instruction_type def get_random_input_mask_for(cls, player): - res = cls() - inputmask(res, player) + """ Secret random input mask according to security model. + + :return: mask (sint), mask (personal cint) + :param size: vector size (int, default 1) + """ + res = cls(), personal(player, cls.clear_type()) + inputmask(res[0], res[1]._v, player) return res @classmethod @@ -2071,15 +2099,13 @@ def reveal(self): @set_instruction_type def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result written to ``Player-Data/Private-Output-P`` :param player: int - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ - masked = self.__class__() - res = personal(player, self.clear_type()) - startprivateoutput(masked, self, player) - stopprivateoutput(res._v, masked.reveal(), player) + mask = self.get_random_input_mask_for(player) + masked = self + mask[0] + res = personal(player, masked.reveal() - mask[1]) return res @@ -2633,21 +2659,20 @@ def raw_mod2m(self, m): @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result potentially written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. - :param player: public integer (int/regint/cint): - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` """ - if not util.is_constant(player) or self.size > 1: + if not util.is_constant(player): secret_mask = sint() player_mask = cint() inputmaskreg(secret_mask, player_mask, regint.conv(player)) return personal(player, (self + secret_mask).reveal() - player_mask) else: - return super(sint, self).reveal_to(player) + res = personal(player, self.clear_type()) + privateoutput(self.size, player, res._v, self) + return res def private_division(self, divisor, active=True, dividend_length=None, divisor_length=None): @@ -4366,12 +4391,9 @@ def multipliable(v, k, f, size): def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Raw representation possibly written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. :param player: public integer (int/regint/cint) - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) @@ -5221,6 +5243,9 @@ def __setitem__(self, index, value): return self.assign(value, addresses) self._store(value, self.get_address(index)) + def to_array(self): + return self + def get_sub(self, start, stop=None): if stop is None: stop = start @@ -5471,6 +5496,10 @@ def shuffle(self): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + def randomize(self, *args): + """ Randomize according to data type. """ + self.assign_vector(self.value_type.get_random(*args, size=len(self))) + def reveal(self): """ Reveal the whole array. @@ -5596,6 +5625,9 @@ def __len__(self): def __iter__(self): return (self[i] for i in range(len(self))) + def to_array(self): + return Array(self.total_size(), self.value_type, address=self.address) + def assign_all(self, value): """ Assign the same value to all entries. @@ -5958,6 +5990,7 @@ def direct_mul_trans(self, other, reduce=True, indices=None): """ assert len(self.sizes) == 2 assert len(other.sizes) == 2 + assert other.address != None if indices is None: assert self.sizes[1] == other.sizes[1] indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] @@ -6145,6 +6178,16 @@ def diag(self): n = self.sizes[0] return self.array.get(regint.inc(n, 0, n + 1)) + def randomize(self, *args): + """ Randomize according to data type. """ + if self.total_size() < program.options.budget: + self.assign_vector( + self.value_type.get_random(*args, size=self.total_size())) + else: + @library.for_range(self.sizes[0]) + def _(i): + self[i].randomize(*args) + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6251,6 +6294,22 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + def get_column(self, index): + """ Get column as vector. + + :param index: regint/cint/int + """ + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.sizes[1]) + return self.value_type.load_mem(addresses) + + def get_column_by_row_indices(self, rows, column): + assert self.value_type.n_elements() == 1 + addresses = rows * self.sizes[1] + \ + regint.inc(len(rows), self.address + column, 0) + return self.value_type.load_mem(addresses) + def set_column(self, index, vector): """ Change column. diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 2a4d6b123..20dfb1bb5 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -47,11 +47,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G) } void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) +{ + Rq_Element a(*this); + a.randomize(G); + partial_key_gen(sk, a, G, noise_boost); +} + +void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost) { FHE_PK& PK = *this; - // Generate the main public key - PK.a0.randomize(G); + a0 = a; // b0=a0*s+p*e0 Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); @@ -77,9 +84,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) mul(es,es,PK.pr); add(PK.Sw_b,PK.Sw_b,es); - // Lowering level as we only decrypt at level 0 - sk.lower_level(); - // bs=bs-p1*s^2 Rq_Element s2; mul(s2,sk,sk); // Mult at level 0 @@ -334,7 +338,7 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, template void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) { - check(*params, pk, pr); + check(*params, pk, FieldD.get_prime()); pk.check_noise(*this); if (decrypt(pk.encrypt(Plaintext_(FieldD)), FieldD) != Plaintext_(FieldD)) diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 72a7ddfa8..30ecc2925 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -150,6 +150,8 @@ class FHE_PK Rq_Element sample_secret_key(PRNG& G); void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1); + void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost = 1); void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 8ae6c2885..0de8bb1e9 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -3,6 +3,11 @@ #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" +FHE_Params::FHE_Params(int n_mults) : + FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1) +{ +} + void FHE_Params::set(const Ring& R, const vector& primes) { @@ -24,6 +29,14 @@ void FHE_Params::set_sec(int sec) throw runtime_error("distributed decryption bound is zero"); } +void FHE_Params::set_matrix_dim(int matrix_dim) +{ + assert(matrix_dim > 0); + if (FFTData[0].get_prime() != 0) + throw runtime_error("cannot change matrix dimension after parameter generation"); + this->matrix_dim = matrix_dim; +} + bigint FHE_Params::Q() const { bigint res = FFTData[0].get_prime(); @@ -40,6 +53,7 @@ void FHE_Params::pack(octetStream& o) const Chi.pack(o); Bval.pack(o); o.store(sec_p); + o.store(matrix_dim); } void FHE_Params::unpack(octetStream& o) @@ -52,6 +66,7 @@ void FHE_Params::unpack(octetStream& o) Chi.unpack(o); Bval.unpack(o); o.get(sec_p); + o.get(matrix_dim); } bool FHE_Params::operator!=(const FHE_Params& other) const diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 8ac400839..9407b0ba4 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -26,10 +26,11 @@ class FHE_Params // Data for distributed decryption int sec_p; bigint Bval; + int matrix_dim; public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {} + FHE_Params(int n_mults = 1); int n_mults() const { return FFTData.size() - 1; } @@ -37,6 +38,9 @@ class FHE_Params void set(const vector& primes); void set_sec(int sec); + void set_matrix_dim(int matrix_dim); + int get_matrix_dim() const { return matrix_dim; } + const vector& FFTD() const { return FFTData; } const bigint& p0() const { return FFTData[0].get_prime(); } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index c6e294a63..7c46a74fe 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2) template <> int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FFT_Data& FTD, bool round_up) + FHE_Params& params, FFT_Data& FTD, bool round_up, int n) { int m = 1024; int lgp = plaintext_length; @@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec, while (true) { tmp_params = params; - SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, + SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec, numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params); bigint p1 = 2 * p * m, p0 = p; while (nb.min_p0(params.n_mults() > 0, p1) > p0) @@ -89,14 +89,14 @@ int generate_semi_setup(int plaintext_length, int sec, template <> int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, P2Data& P2D, bool round_up) + FHE_Params& params, P2Data& P2D, bool round_up, int n) { if (params.n_mults() > 0) throw runtime_error("only implemented for 0-level BGV"); gf2n_short::init_field(plaintext_length); int m; char_2_dimension(m, plaintext_length); - SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec, + SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec, numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); int lgp0 = numBits(nb.min_p0(false, 0)); int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up); @@ -590,6 +590,9 @@ void char_2_dimension(int& m, int& lg2) m=5797; lg2=40; break; + case 16: + m = 13107; + break; default: throw runtime_error("field size not supported"); break; diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index c0a2ecfea..acaba70b6 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -52,7 +52,7 @@ void generate_setup(int nparties, int lgp, int lg2, // semi-homomorphic, includes slack template int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FD& FieldD, bool round_up); + FHE_Params& params, FD& FieldD, bool round_up, int n = 1); // field-independent semi-homomorphic setup int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index 7ab8e5172..f2e151c42 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -39,6 +39,7 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); + int matrix_dim = params.get_matrix_dim(); #ifdef NOISY cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "V_s: " << V_s << endl; @@ -48,9 +49,11 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, cout << "log(slack): " << slack << endl; cout << "B_clean: " << B_clean << endl; cout << "B_scale: " << B_scale << endl; + cout << "matrix dimension: " << matrix_dim << endl; #endif - drown = 1 + n * (bigint(1) << sec); + assert(matrix_dim > 0); + drown = 1 + matrix_dim * n * (bigint(1) << sec); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 812560a3a..554d4dc10 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -50,6 +50,7 @@ void Ring_Element::prepare_push() void Ring_Element::allocate() { + assert(FFTD); element.resize(FFTD->phi_m()); } diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index af7a664b5..531df90f7 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b) } } +void Rq_Element::add(octetStream& os) +{ + Rq_Element tmp(*this); + tmp.unpack(os); + *this += tmp; +} + void Rq_Element::randomize(PRNG& G,int l) { set_level(l); @@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p) // Now add delta back onto a0 Rq_Element bb(b0,b1); - add(*this,*this,bb); + ::add(*this,*this,bb); // Now divide by p1 mod p0 modp p1_inv,pp; diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index d5e718419..a58cb7de0 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -93,12 +93,14 @@ class Rq_Element friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); + void add(octetStream& os); + template Rq_Element& operator+=(const vector& other); - Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; } + Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; } - Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; } + Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; } Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; } template Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; } @@ -176,7 +178,7 @@ Rq_Element& Rq_Element::operator+=(const vector& other) { Rq_Element tmp = *this; tmp.from(Iterator(other), lev); - add(*this, *this, tmp); + ::add(*this, *this, tmp); return *this; } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 0f5d1fe86..48a8a6ef8 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -203,7 +203,7 @@ template void PartSetup::secure_init(Player& P, MachineBase& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); } template diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 732904b39..92632002c 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -130,6 +130,13 @@ void Multiplier::report_size(ReportType type, MemoryUsage& res) res += memory_usage; } +template +const vector& Multiplier::get_multiplicands( + const vector >& others_ct, const FHE_PK&) +{ + return others_ct[P.get_full_player().get_player(-P.get_offset())]; +} + template class Multiplier; template class Multiplier; diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index e2e1ce660..9ab517a66 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -55,6 +55,9 @@ class Multiplier size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); size_t report_volatile() { return volatile_capacity; } + + const vector& get_multiplicands( + const vector>& others_ct, const FHE_PK&); }; #endif /* FHEOFFLINE_MULTIPLIER_H_ */ diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index bba83b5fd..047c84f2c 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -9,6 +9,7 @@ #include "Math/Setup.h" #include "FHEOffline/Proof.h" #include "FHEOffline/PairwiseMachine.h" +#include "FHEOffline/TemiSetup.h" #include "Tools/Commit.h" #include "Tools/Bundle.h" #include "Processor/OnlineOptions.h" @@ -53,7 +54,7 @@ void PairwiseSetup::init(const Player& P, int sec, int plaintext_length, template void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); alpha = FieldD; machine.sk = FHE_SK(params, FieldD.get_prime()); for (auto& pk : machine.other_pks) @@ -62,13 +63,14 @@ void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int pla template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec) + int plaintext_length, int sec, FHE_Params& params) { machine.sec = sec; sec = max(sec, 40); machine.drown_sec = sec; string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); @@ -85,7 +87,6 @@ void secure_init(T& setup, Player& P, U& machine, { cout << "Finding parameters for security " << sec << " and field size ~2^" << plaintext_length << endl; - setup.params = setup.params.n_mults(); setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; @@ -208,5 +209,8 @@ void PairwiseSetup::set_alphai(T alphai) template class PairwiseSetup; template class PairwiseSetup; -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); + +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h index 8e16eaf34..f6482edec 100644 --- a/FHEOffline/PairwiseSetup.h +++ b/FHEOffline/PairwiseSetup.h @@ -15,7 +15,7 @@ class MachineBase; template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec); + int plaintext_length, int sec, FHE_Params& params); template class PairwiseSetup diff --git a/FHEOffline/SimpleDistDecrypt.cpp b/FHEOffline/SimpleDistDecrypt.cpp index 3774cd3c1..c8b923123 100644 --- a/FHEOffline/SimpleDistDecrypt.cpp +++ b/FHEOffline/SimpleDistDecrypt.cpp @@ -18,7 +18,12 @@ void SimpleDistDecrypt::reshare(Plaintext& EC) { (void)EC; + m = reshare(cm); +} +template +Plaintext_ SimpleDistDecrypt::reshare(const Ciphertext& cm) +{ PRNG G; G.ReSeed(); this->f.randomize(G, Full); @@ -27,10 +32,13 @@ void SimpleDistDecrypt::reshare(Plaintextrun(cm); // Step 4 + Plaintext_ m(this->f.get_field()); if (this->P.my_num()==0) { sub(m,this->mf,this->f); } else { m=this->f; m.negate(); } + + return m; } diff --git a/FHEOffline/SimpleDistDecrypt.h b/FHEOffline/SimpleDistDecrypt.h index 9589f15a1..c929a7990 100644 --- a/FHEOffline/SimpleDistDecrypt.h +++ b/FHEOffline/SimpleDistDecrypt.h @@ -20,6 +20,7 @@ class SimpleDistDecrypt : public DistDecrypt void reshare(Plaintext& m, const Ciphertext& cm, EncCommitBase& EC); + Plaintext_ reshare(const Ciphertext& cm); }; #endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */ diff --git a/FHEOffline/TemiSetup.cpp b/FHEOffline/TemiSetup.cpp new file mode 100644 index 000000000..fc222ed51 --- /dev/null +++ b/FHEOffline/TemiSetup.cpp @@ -0,0 +1,59 @@ +/* + * TemiSetup.cpp + * + */ + +#include "TemiSetup.h" +#include "PairwiseSetup.h" +#include "FHE/NTL-Subs.h" +#include "Protocols/HemiOptions.h" + +template +TemiSetup::TemiSetup() +{ + this->params = FHE_Params(0); + this->pk = {this->params, 0}; + this->sk = {this->params, 0}; + this->calpha = this->params; + this->params.set_matrix_dim( + HemiOptions::singleton.plain_matmul ? + 1 : OnlineOptions::singleton.batch_size); +} + +template +void TemiSetup::secure_init(Player& P, int plaintext_length) +{ + MachineBase machine; + ::secure_init(*this, P, machine, plaintext_length, 0, this->params); +} + +template +void TemiSetup::generate(Player& P, MachineBase&, + int plaintext_length, int sec) +{ + generate_semi_setup(plaintext_length, sec, this->params, this->FieldD, + false, P.num_players()); + this->sk = {this->params, this->FieldD.get_prime()}; + this->pk = {this->params, this->FieldD.get_prime()}; +} + +template +void TemiSetup::key_and_mac_generation(Player& P, MachineBase&, int, + true_type) +{ + Rq_Element a(this->params); + GlobalPRNG GG(P); + a.randomize(GG); + SeededPRNG G; + auto sk = this->pk.sample_secret_key(G); + this->sk.assign(sk); + this->pk.partial_key_gen(sk, a, G); + TreeSum ts; + vector pks; + pks.push_back(this->pk.b()); + ts.run(pks, P); + this->pk.assign(this->pk.a(), pks[0]); +} + +template class TemiSetup; +template class TemiSetup; diff --git a/FHEOffline/TemiSetup.h b/FHEOffline/TemiSetup.h new file mode 100644 index 000000000..483cb0ee6 --- /dev/null +++ b/FHEOffline/TemiSetup.h @@ -0,0 +1,34 @@ +/* + * TemiSetup.h + * + */ + +#ifndef FHEOFFLINE_TEMISETUP_H_ +#define FHEOFFLINE_TEMISETUP_H_ + +#include "FHE/FHE_Keys.h" +#include "FHEOffline/SimpleMachine.h" + +template +class TemiSetup : public PartSetup +{ +public: + static string name() + { + return "TemiParams"; + } + + static string protocol_name(int) + { + return "Temi"; + } + + TemiSetup(); + + void secure_init(Player& P, int plaintext_length); + void generate(Player& P, MachineBase&, int plaintext_length, int sec); + + void key_and_mac_generation(Player& P, MachineBase&, int, true_type); +}; + +#endif /* FHEOFFLINE_TEMISETUP_H_ */ diff --git a/GC/Memory.h b/GC/Memory.h index 359677a20..006a91d94 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -47,11 +47,11 @@ inline void Memory::check_index(Integer index) const ss << T::type_string() << " memory overflow: " << i << "/" << vector::size(); throw Processor_Error(ss.str()); } -#endif #ifdef DEBUG_MEMORY cout << typeid(T).name() << " at " << this << " index " << i << ": " << vector::operator[](i) << endl; #endif +#endif } template diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 48f75b8f2..6d9f26525 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -122,6 +122,7 @@ class RepSecretBase : public FixedVec, public ShareSecret static const bool dishonest_majority = false; static const bool variable_players = false; static const bool needs_ot = false; + static const bool has_mac = false; static string type_string() { return "replicated secret"; } static string phase_name() { return "Replicated computation"; } diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9b6c84782..9cdde3dc7 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -49,6 +49,7 @@ class VectorSecret : public Secret static const bool dishonest_majority = T::dishonest_majority; static const bool variable_players = T::variable_players; static const bool needs_ot = T::needs_ot; + static const bool has_mac = T::has_mac; static const bool expensive_triples = false; static const int default_length = 64; diff --git a/GC/instructions.h b/GC/instructions.h index 66ae46d22..49443cc23 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -55,7 +55,7 @@ X(BITDECC, PROC.bitdecc(EXTRA, C0)) \ X(SHRCBI, C0 = PC1 >> IMM) \ X(SHLCBI, C0 = PC1 << IMM) \ - X(LDBITS, S0.load_clear(REG1, IMM)) \ + X(LDBITS, S0.load_clear(REG1, int(IMM))) \ X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \ X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \ X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 7697c5124..9f18d3a6f 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -23,6 +23,7 @@ #include "Protocols/Shamir.hpp" #include "Protocols/ShamirMC.hpp" #include "Protocols/MaliciousShamirMC.hpp" +#include "Protocols/MaliciousShamirPO.hpp" #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" diff --git a/Machines/temi-party.cpp b/Machines/temi-party.cpp new file mode 100644 index 000000000..12e99dc27 --- /dev/null +++ b/Machines/temi-party.cpp @@ -0,0 +1,37 @@ +/* + * temi-party.cpp + * + */ + +#include "Protocols/TemiShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "FHE/P2Data.h" +#include "Tools/ezOptionParser.h" +#include "GC/SemiSecret.h" +#include "GC/SemiPrep.h" + +#include "Processor/FieldMachine.hpp" +#include "Protocols/TemiPrep.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/SemiPrep.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/SemiMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/Hemi.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/SemiHonestRepPrep.h" +#include "Math/gfp.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + HemiOptions::singleton = {opt, argc, argv}; + DishonestMajorityFieldMachine(argc, argv, + opt); +} diff --git a/Makefile b/Makefile index e40528b8c..4f558e1d6 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr all: overdrive she-offline -arithmetic: hemi-party.x soho-party.x gear +arithmetic: semi-he gear -include $(DEPS) include $(wildcard *.d static/*.d) @@ -87,6 +87,7 @@ she-offline: Check-Offline.x spdz2-offline.x overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x +semi-he: hemi-party.x soho-party.x temi-party.x rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x @@ -210,6 +211,7 @@ static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) +temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) @@ -217,6 +219,7 @@ lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/Lo highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOBJS) +static/temi-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS) static/cowgear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 55983e0b0..c0b2373ed 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -14,11 +14,6 @@ using namespace std; #include "Tools/random.h" #include "field_types.h" -template class ReplicatedMC; -template class ReplicatedInput; -template class ReplicatedPrivateOutput; -template class Replicated; - template class FixedVec { diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index f30e71037..13d700fc1 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -233,7 +233,7 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* if (mpn_cmp(ans+T,prA,T+1)>=0) { mpn_sub_fixed_n(z,ans+T,prA); } else - { inline_mpn_copyi(z,ans+T,T); } + { inline_mpn_copyi(z,ans+T); } #else Mont_Mult(z, x, y, t); #endif diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 1a6fe41d1..f9491fb7b 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -18,15 +18,21 @@ bool gf2n_::useC; word gf2n_short_table[256][256]; -#define num_2_fields 6 +#define num_2_fields 7 /* Require * 2*(n-1)-64+t1<64 */ -int fields_2[num_2_fields][4] = { - {4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1}, - }; - +int fields_2[num_2_fields][4] = +{ + { 4, 1, 0, 0 }, + { 8, 4, 3, 1 }, + { 16, 5, 3, 1 }, + { 28, 1, 0, 0 }, + { 40, 20, 15, 10 }, + { 63, 1, 0, 0 }, + { 128, 7, 2, 1 }, +}; template void gf2n_::init_tables() diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index b55a6b7ec..b1c5642be 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -24,6 +24,12 @@ inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t si avx_memcpy(dest, src, size * sizeof(mp_limb_t)); } +template +inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src) +{ + avx_memcpy(dest, src); +} + inline void debug_print(const char* name, const mp_limb_t* x, int n) { (void)name, (void)x, (void)n; diff --git a/Networking/Player.h b/Networking/Player.h index 9c90dbd1f..ff4bdcd1d 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -542,6 +542,7 @@ class OffsetPlayer : public TwoPartyPlayer int other_player_num() const { return P.get_player(offset); } int num_players() const { return 2; } int get_offset() const { return offset; } + Player& get_full_player() const { return P; } void send(octetStream& o) const { P.send_to(P.get_player(offset), o); } void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); } diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 988565854..730ffa6f6 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -206,6 +206,18 @@ void BaseOT::exec_base(bool new_receiver_inputs) receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); } } + +#ifdef BASE_OT_DEBUG + for (j = 0; j < 4; j++) + for (k = 0; k < AES_BLK_SIZE; k++) + { + printf("%4d-th receiver key:", i+j); + for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); + printf("\n"); + } + + printf("\n"); +#endif } } @@ -244,12 +256,6 @@ void BaseOT::exec_base(bool new_receiver_inputs) for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]); printf("\n"); } - if (ot_role & RECEIVER) - { - printf("%4d-th receiver key:", i+j); - for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); - printf("\n"); - } } printf("\n"); diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index ef735279a..ea8239a5b 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -25,7 +25,7 @@ void Binary_File_IO::write_to_file(const string filename, if (start_pos != -1) { - long write_pos = start_pos * T::size(); + long write_pos = file_signature().get_total_length() + start_pos * T::size(); // fill with zeros if needed for (long i = outf.tellp(); i < write_pos; i++) outf.put(0); @@ -50,10 +50,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, inf.open(filename, ios::in | ios::binary); if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } + check_file_signature(inf, filename).get_length(); + auto data_start = inf.tellg(); + int size_in_bytes = T::size() * buffer.size(); int n_read = 0; char read_buffer[size_in_bytes]; - inf.seekg(start_posn * T::size()); + inf.seekg(start_posn * T::size(), iostream::cur); do { inf.read(read_buffer + n_read, size_in_bytes - n_read); @@ -62,7 +65,9 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, if (inf.eof()) { stringstream ss; - ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes)."; + ss << "Got to EOF when reading from disk (expecting " << size_in_bytes + << " bytes from " << (long(data_start) + start_posn * T::size()) + << ")."; throw file_error(ss.str()); } if (inf.fail()) @@ -74,7 +79,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, } while (n_read < size_in_bytes); - end_posn = inf.tellg() / T::size(); + end_posn = (inf.tellg() - data_start) / T::size(); assert (end_posn == start_posn + int(buffer.size())); //Check if at end of file by getting 1 more char. diff --git a/Processor/Input.h b/Processor/Input.h index 98c6c83b0..728c81f6a 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -32,6 +32,15 @@ class InputBase Buffer buffer; Timer timer; + // Send my inputs (not generally available) + virtual void send_mine() { throw not_implemented(); } + // Get share for next input of mine (not generally available) + virtual T finalize_mine() { throw not_implemented(); } + // Store share for next input from ``player`` from buffer ``o`` + // in ``target`` (not generally available) + virtual void finalize_other(int, T&, octetStream&, int = -1) + { throw not_implemented(); } + public: vector os; int values_input; @@ -61,18 +70,12 @@ class InputBase /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; /// Schedule input from all players - void add_from_all(const clear& input, int n_bits = -1); + void add_from_all(const typename T::open_type& input, int n_bits = -1); - /// Send my inputs - virtual void send_mine() = 0; /// Run input protocol for all players virtual void exchange(); - /// Get share for next input of mine - virtual T finalize_mine() = 0; - /// Store share for next input from ``player`` from buffer ``o`` in ``target`` - virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; - /// Get share for next input from ``player` + /// Get share for next input from ``player`` virtual T finalize(int player, int n_bits = -1); void raw_input(SubProcessor& proc, const vector& args, int size); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index b9f7a77ab..246c9eb1d 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -113,7 +113,7 @@ void Input::add_other(int player, int) } template -void InputBase::add_from_all(const clear& input, int n_bits) +void InputBase::add_from_all(const typename T::open_type& input, int n_bits) { for (int i = 0; i < P->num_players(); i++) if (i == P->my_num()) diff --git a/Processor/Instruction.h b/Processor/Instruction.h index ca062cbcb..a7e1e3185 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -106,6 +106,7 @@ enum MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -127,6 +128,7 @@ enum INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 25fa666f4..1bc46f94f 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -200,14 +200,17 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case USE: case USE_INP: case USE_EDABIT: - case STARTPRIVATEOUTPUT: - case GSTARTPRIVATEOUTPUT: - case STOPPRIVATEOUTPUT: - case GSTOPPRIVATEOUTPUT: case DIGESTC: + case INPUTMASK: + case GINPUTMASK: get_ints(r, s, 2); n = get_int(s); break; + case STARTPRIVATEOUTPUT: + case GSTARTPRIVATEOUTPUT: + case STOPPRIVATEOUTPUT: + case GSTOPPRIVATEOUTPUT: + throw runtime_error("two-stage private output not supported any more"); case USE_MATMUL: get_ints(r, s, 3); n = get_int(s); @@ -237,8 +240,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case PRINTREGB: case GPRINTREG: case LDINT: - case INPUTMASK: - case GINPUTMASK: case INV2M: case CONDPRINTSTR: case CONDPRINTSTRB: @@ -290,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RAWINPUT: case GRAWINPUT: case INPUTPERSONAL: + case SENDPERSONAL: + case PRIVATEOUTPUT: case TRUNC_PR: case RUN_TAPE: num_var_args = get_int(s); @@ -599,6 +602,7 @@ int BaseInstruction::get_reg_type() const case PUBINPUT: case FLOATOUTPUT: case READSOCKETC: + case PRIVATEOUTPUT: return CINT; default: if (is_gf2n_instruction()) @@ -738,10 +742,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const skip = 1; break; case INPUTPERSONAL: + case PRIVATEOUTPUT: size_offset = -2; offset = 2; skip = 4; break; + case SENDPERSONAL: + size_offset = -2; + offset = 2; + skip = 5; + break; case READSOCKETS: case READSOCKETC: case READSOCKETINT: @@ -939,13 +949,11 @@ inline void Instruction::execute(Processor& Proc) const break; case INPUTMASK: Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); - if (n == Proc.P.my_num()) - Proc.temp.rrp.output(Proc.private_output, false); + Proc.write_Cp(r[1], Proc.temp.rrp); break; case GINPUTMASK: Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); - if (n == Proc.P.my_num()) - Proc.temp.ans2.output(Proc.private_output, false); + Proc.write_C2(r[1], Proc.temp.ans2); break; case INPUT: sint::Input::template input>(Proc.Procp, start, size); @@ -974,6 +982,12 @@ inline void Instruction::execute(Processor& Proc) const case INPUTPERSONAL: Proc.Procp.input_personal(start); return; + case SENDPERSONAL: + Proc.Procp.send_personal(start); + return; + case PRIVATEOUTPUT: + Proc.Procp.private_output(start); + return; // Note: Fp version has different semantics for NOTC than GNOTC case NOTC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); @@ -1202,18 +1216,6 @@ inline void Instruction::execute(Processor& Proc) const Proc.binary_output.write((char*) &tmp, sizeof(double)); } break; - case STARTPRIVATEOUTPUT: - Proc.privateOutputp.start(n,r[0],r[1]); - break; - case GSTARTPRIVATEOUTPUT: - Proc.privateOutput2.start(n,r[0],r[1]); - break; - case STOPPRIVATEOUTPUT: - Proc.privateOutputp.stop(n,r[0],r[1]); - break; - case GSTOPPRIVATEOUTPUT: - Proc.privateOutput2.stop(n,r[0],r[1]); - break; case PREP: Procp.DataF.get(Proc.Procp.get_S(), r, start, size); return; diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index a43c9d475..cd318f1aa 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -97,12 +97,19 @@ Machine::Machine(int my_number, Names& playerNames, // initialize persistence if necessary for (auto& prog : progs) { - if (prog.writes_persistance) + if (prog.writes_persistence) { string filename = Binary_File_IO::filename(my_number); ifstream pers(filename); - if (pers.fail()) - ofstream pers(filename, ios::binary); + try + { + check_file_signature(pers, filename); + } + catch (signature_mismatch&) + { + ofstream pers(filename, ios::binary); + file_signature().output(pers); + } break; } } @@ -418,12 +425,14 @@ void Machine::run() cerr << "Full broadcast" << endl; #endif +#ifdef CHOP_MEMORY // Reduce memory size to speed up unsigned max_size = 1 << 20; if (M2.size_s() > max_size) M2.resize_s(max_size); if (Mp.size_s() > max_size) Mp.resize_s(max_size); +#endif // Write out the memory to use next time ofstream outf(memory_filename(), ios::out | ios::binary); diff --git a/Processor/Memory.h b/Processor/Memory.h index 9ec02d2b8..1fbeda7ec 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -44,9 +44,9 @@ class Memory static void check_index(const vector& M, size_t i) { (void) M, (void) i; -#ifdef NO_CHECK_INDEX +#ifndef NO_CHECK_INDEX if (i >= M.size()) - throw overflow("memory", i, M.size()); + throw overflow(U::type_string() + " memory", i, M.size()); #endif } diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index c3c3e01bf..ef767441b 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -19,6 +19,9 @@ void MemoryPart::minimum_size(size_t size) { if (size > this->size()) this->resize(size); +#ifdef DEBUG_MEMORY_SIZE + cerr << T::type_string() << " memory has now size " << this->size() << endl; +#endif } catch (bad_alloc&) { @@ -58,9 +61,9 @@ istream& operator>>(istream& s,Memory& M) int len; s >> len; - M.resize_s(len); + M.MS.minimum_size(len); s >> len; - M.resize_c(len); + M.MC.minimum_size(len); s.seekg(1, istream::cur); for (unsigned int i=0; i& proc; + typename T::MAC_Check MC; deque masks; public: - PrivateOutput(SubProcessor& proc) : proc(proc) { }; + PrivateOutput(SubProcessor& proc); + ~PrivateOutput(); - void start(int player, int target, int source); - void stop(int player, int dest, int source); - - T start(int player, const T& source); - typename T::clear stop(int player, const typename T::clear& masked); + void prepare_sending(const T& source, int player); + void exchange(); + typename T::clear finalize(int player); }; #endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ diff --git a/Processor/PrivateOutput.hpp b/Processor/PrivateOutput.hpp index 977e7e15d..d2cee8a14 100644 --- a/Processor/PrivateOutput.hpp +++ b/Processor/PrivateOutput.hpp @@ -7,13 +7,21 @@ #include "Processor.h" template -void PrivateOutput::start(int player, int target, int source) +PrivateOutput::PrivateOutput(SubProcessor& proc) : + proc(proc), MC(proc.MC.get_alphai()) { - proc.get_S_ref(target) = start(player, proc.get_S_ref(source)); + MC.init_open(proc.P); + MC.set_prep(proc.DataF); } template -T PrivateOutput::start(int player, const T& source) +PrivateOutput::~PrivateOutput() +{ + MC.Check(proc.P); +} + +template +void PrivateOutput::prepare_sending(const T& source, int player) { assert (player < proc.P.num_players()); open_type mask; @@ -24,26 +32,25 @@ T PrivateOutput::start(int player, const T& source) if (player == proc.P.my_num()) masks.push_back(mask); - return res; + MC.prepare_open(res); } template -void PrivateOutput::stop(int player, int dest, int source) +void PrivateOutput::exchange() { - auto& value = proc.get_C_ref(dest); - value = stop(player, proc.get_C_ref(source)); - if (proc.Proc) - value.output(proc.Proc->private_output, false); + MC.exchange(proc.P); } template -typename T::clear PrivateOutput::stop(int player, const typename T::clear& source) +typename T::clear PrivateOutput::finalize(int player) { - typename T::clear value; + auto res = MC.finalize_open(); + if (player == proc.P.my_num()) { - value = source - masks.front(); + res -= masks.front(); masks.pop_front(); } - return value; + + return res; } diff --git a/Processor/Processor.h b/Processor/Processor.h index c91b677bf..38ea7f258 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -71,6 +71,8 @@ class SubProcessor void conv2ds(const Instruction& instruction); void input_personal(const vector& args); + void send_personal(const vector& args); + void private_output(const vector& args); CheckVector& get_S() { @@ -110,7 +112,6 @@ class ArithmeticProcessor : public ProcessorBase ifstream private_input; ifstream public_input; ofstream public_output; - ofstream private_output; ofstream binary_output; int sent, rounds; @@ -172,9 +173,6 @@ class Processor : public ArithmeticProcessor SubProcessor Proc2; SubProcessor Procp; - typename sgf2n::PrivateOutput privateOutput2; - typename sint::PrivateOutput privateOutputp; - unsigned int PC; TempVars temp; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index c55a6dfc1..d74594b3d 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -4,9 +4,8 @@ #include "Processor/Processor.h" #include "Processor/Program.h" #include "GC/square64.h" +#include "SpecificPrivateOutput.h" -#include "Protocols/ReplicatedInput.hpp" -#include "Protocols/ReplicatedPrivateOutput.hpp" #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" #include "GC/ShareThread.hpp" @@ -63,7 +62,6 @@ Processor::Processor(int thread_num,Player& P, share_thread(DataF.DataFb, P, machine.get_bit_mac_key()), Procb(machine.bit_memories), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), - privateOutput2(Proc2),privateOutputp(Procp), external_clients(P.my_num()), binary_file_io(Binary_File_IO()) { @@ -74,7 +72,6 @@ Processor::Processor(int thread_num,Player& P, private_input_filename = (get_filename(PREP_DIR "Private-Input-",true)); private_input.open(private_input_filename.c_str()); public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out); - private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out); binary_output.open( get_parameterized_filename(P.my_num(), thread_num, PREP_DIR "Binary-Output"), ios_base::out); @@ -654,6 +651,37 @@ void SubProcessor::input_personal(const vector& args) S[args[i + 2] + j] = input.finalize(args[i + 1]); } +template +void SubProcessor::private_output(const vector& args) +{ + typename T::PrivateOutput output(*this); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + { + int player = args[i + 1]; + output.prepare_sending(S.at(args[i + 3] + j), player); + } + output.exchange(); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + C.at(args[i + 2] + j) = output.finalize(args[i + 1]); +} + +template +void SubProcessor::send_personal(const vector& args) +{ + octetStreams to_send(P), to_receive(P); + for (size_t i = 0; i < args.size(); i += 5) + if (args[i + 3] == P.my_num()) + for (int j = 0; j < args[i]; j++) + C[args[i + 4] + j].pack(to_send[args[i + 1]]); + P.send_receive_all(to_send, to_receive); + for (size_t i = 0; i < args.size(); i += 5) + if (args[i + 1] == P.my_num()) + for (int j = 0; j < args[i]; j++) + C[args[i + 2] + j].unpack(to_receive[args[i + 3]]); +} + template typename sint::clear Processor::get_inverse2(unsigned m) { diff --git a/Processor/Program.cpp b/Processor/Program.cpp index c33039428..dac73400b 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -23,7 +23,7 @@ void Program::compute_constants() max_mem[reg_type] = max(max_mem[reg_type], p[i].get_mem(RegType(reg_type))); } - writes_persistance |= p[i].opcode == WRITEFILESHARE; + writes_persistence |= p[i].opcode == WRITEFILESHARE; } } diff --git a/Processor/Program.h b/Processor/Program.h index a41c9e2a6..87a263f08 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -30,10 +30,10 @@ class Program public: - bool writes_persistance; + bool writes_persistence; Program(int nplayers) : offline_data_used(nplayers), - unknown_usage(false), writes_persistance(false) + unknown_usage(false), writes_persistence(false) { compute_constants(); } // Read in a program diff --git a/Processor/SpecificPrivateOutput.h b/Processor/SpecificPrivateOutput.h new file mode 100644 index 000000000..7878db1cd --- /dev/null +++ b/Processor/SpecificPrivateOutput.h @@ -0,0 +1,65 @@ +/* + * SpecificPrivateOutput.h + * + */ + +#ifndef PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ +#define PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ + +template +class SpecificPrivateOutput +{ + deque secrets; + vector pos; + Player& P; + vector active; + +public: + SpecificPrivateOutput(SubProcessor& proc) : + P(proc.P) + { + for (int i = 0; i < P.num_players(); i++) + pos.push_back(new typename T::PO(proc.P)); + active.resize(P.num_players()); + } + + ~SpecificPrivateOutput() + { + for (auto& x : pos) + delete x; + } + + void prepare_sending(const T& secret, int player) + { + pos[player]->prepare_sending(secret, player); + if (P.my_num() == player) + secrets.push_back(secret); + active[player] = true; + } + + void exchange() + { + for (int i = 0; i < this->P.num_players(); i++) + if (active[i]) + { + if (i == this->P.my_num()) + pos[i]->receive(); + else + pos[i]->send(i); + } + } + + typename T::clear finalize(int player) + { + if (player == this->P.my_num()) + { + T secret = secrets.front(); + secrets.pop_front(); + return pos[player]->finalize(secret); + } + else + return {}; + } +}; + +#endif /* PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ */ diff --git a/Programs/Source/falcon_alex.mpc b/Programs/Source/falcon_alex.mpc new file mode 100644 index 000000000..3c535248f --- /dev/null +++ b/Programs/Source/falcon_alex.mpc @@ -0,0 +1,100 @@ +from Compiler.ml import keras +import Compiler.ml as tf + +try: + n_epochs = int(program.args[1]) +except (ValueError, IndexError): + n_epochs = 10 + +try: + batch_size = int(program.args[2]) +except (ValueError, IndexError): + batch_size = 128 + +try: + n_threads = int(program.args[3]) +except (ValueError, IndexError): + n_threads = 36 + +#Instantiation +AlexNet = [] + +padding = 'same' +batchnorm = 'batchnorm' in program.args + +#1st Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=96, input_shape=(32,32,3), kernel_size=(11,11), strides=(4,4), padding=9)) +AlexNet.append(keras.layers.Activation('relu')) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=3, strides=(2,2))) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) + +#2nd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=1)) + +#3rd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#4th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#5th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#Passing it to a Fully Connected layer +# 1st Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#2nd Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#Output Layer +AlexNet.append(keras.layers.Dense(10)) + + +tf.set_n_threads(n_threads) +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) + +training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +if 'no_acc' not in program.args: + training_labels.input_from(0) + training_samples.input_from(0) + + test_labels.input_from(0) + test_samples.input_from(0) + +model = tf.keras.models.Sequential(AlexNet) + +model.compile_by_args(program) + +model.build(training_samples.sizes) +model.summary() + +opt = model.fit( + training_samples, + training_labels, + epochs=n_epochs, + batch_size=batch_size, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/keras_cifar_lenet.mpc b/Programs/Source/keras_cifar_lenet.mpc new file mode 100644 index 000000000..882d2e187 --- /dev/null +++ b/Programs/Source/keras_cifar_lenet.mpc @@ -0,0 +1,45 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +training_labels.input_from(0) +training_samples.input_from(0) + +test_labels.input_from(0) +test_samples.input_from(0) + +from Compiler import ml +tf = ml +ml.set_n_threads(36) + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +optim = tf.keras.optimizers.Adam(amsgrad=True) + +model.compile(optimizer=optim) + +opt = model.fit( + training_samples, + training_labels, + epochs=10, + batch_size=128, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/keras_mnist_dense.mpc b/Programs/Source/keras_mnist_dense.mpc index a525c0650..76b1e23f5 100644 --- a/Programs/Source/keras_mnist_dense.mpc +++ b/Programs/Source/keras_mnist_dense.mpc @@ -21,7 +21,8 @@ tf = ml layers = [ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), - tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(128), + tf.keras.layers.Activation('relu'), tf.keras.layers.Dense(10, activation='softmax') ] diff --git a/Programs/Source/keras_mnist_lenet.mpc b/Programs/Source/keras_mnist_lenet.mpc index 9fdac27fd..674cf4036 100644 --- a/Programs/Source/keras_mnist_lenet.mpc +++ b/Programs/Source/keras_mnist_lenet.mpc @@ -20,8 +20,21 @@ tf = ml layers = [ tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), +] + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ tf.keras.layers.MaxPooling2D(2), tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), +] + + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ tf.keras.layers.MaxPooling2D(2), tf.keras.layers.Flatten(), tf.keras.layers.Dropout(0.5), diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 9dc8a6851..37cd73d2d 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -21,6 +21,8 @@ elif 'debug' in program.args: n_test = 100 elif 'debug5000' in program.args: N = n_test = 5000 +elif 'mini' in program.args: + N = n_test = 10 else: N = 60000 n_test = 10000 @@ -39,6 +41,7 @@ except: batch_size = N N = min(N, 10000) +batch_size = min(batch_size, N) ml.Layer.back_batch_size = batch_size try: @@ -71,6 +74,9 @@ else: ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml), ml.Dense(N, n_inner, 10, debug=debug_ml)] +if 'batchnorm' in program.args: + layers.insert(1, ml.BatchNorm([N, n_inner])) + if 'dropout' in program.args: for i in range(len(layers) - 1, 0, -1): layers.insert(i, ml.Dropout(N, n_inner)) diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index 6ea76b260..04ca11ad6 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -53,7 +53,7 @@ except: ml.Layer.back_batch_size = batch_size layers = [ - ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'), + ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'), ml.MaxPool([N, 24, 24, 20]), ml.Relu([N, 12, 12, 20]), ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'), @@ -66,6 +66,12 @@ layers = [ layers += [ml.MultiOutput.from_args(program, n_examples, 10)] +if 'batchnorm' in program.args: + for arg in program.args: + assert not arg.startswith('dropout') + layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args)) + layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args)) + if 'dropout' in program.args or 'dropout2' in program.args: layers.insert(8, ml.Dropout(N, 500)) elif 'dropout.25' in program.args: diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index c3a919b3d..9c6f0b9c8 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -85,6 +85,12 @@ void Atlas::exchange() resharing.add_mine(e); } + for (size_t i = 0; i < min(masks.size(), size_t(P.num_players())); i++) + { + int j = (base_king + i) % P.num_players(); + resharing.add_sender(j); + } + resharing.exchange(); } diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index e67b28a97..1eebd3b73 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -27,7 +27,7 @@ HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, new HemiMatrixPrep(dims[0], dims[1], dims[2], - dynamic_cast&>(processor.DataF))}); + dynamic_cast(processor.DataF))}); return *matrix_preps.at(dims); } diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index e48d92571..ea5a7211c 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -18,17 +18,18 @@ template class HemiMatrixPrep : public BufferPrep> { typedef BufferPrep> super; + typedef typename T::LivePrep LivePrep; int n_rows, n_inner, n_cols; bool swapped; DataPositions* usage; - HemiPrep* prep; + LivePrep* prep; HemiMatrixPrep(const HemiMatrixPrep&) = delete; public: - HemiMatrixPrep(int n_rows, int n_inner, int n_cols, HemiPrep& prep) : + HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) : super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), prep(&prep) { diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 82b28431c..f42212995 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -87,11 +87,10 @@ void HemiMatrixPrep::buffer_triples() assert(prep); auto& multipliers = prep->get_multipliers(); - assert(prep->pairwise_machine); - auto& FTD = prep->pairwise_machine->setup_p.FieldD; - auto& pk = prep->pairwise_machine->pk; + auto& FTD = prep->get_FTD(); + auto& pk = prep->get_pk(); int n_matrices = FTD.num_slots() / n_rows; -#ifdef VERBOSE +#ifdef VERBOSE_HE fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner, n_inner, n_cols); fflush(stderr); @@ -103,20 +102,23 @@ void HemiMatrixPrep::buffer_triples() AddableVector> C(n_matrices); MatrixRandMultJob job(C, A, B); - if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + if (T::local_mul) { - auto& queues = BaseMachine::s().queues; - int start = queues.distribute(job, n_matrices); - job.begin = start; - job.end = n_matrices; - matrix_rand_mult(job); - queues.wrap_up(job); - } - else - { - job.begin = 0; - job.end = n_matrices; - matrix_rand_mult(job); + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + { + auto& queues = BaseMachine::s().queues; + int start = queues.distribute(job, n_matrices); + job.begin = start; + job.end = n_matrices; + matrix_rand_mult(job); + queues.wrap_up(job); + } + else + { + job.begin = 0; + job.end = n_matrices; + matrix_rand_mult(job); + } } #ifdef VERBOSE_HE @@ -130,26 +132,35 @@ void HemiMatrixPrep::buffer_triples() assert(prep->proc); auto& P = prep->proc->P; - Bundle bundle(P); - bundle.mine.store(diag.ciphertexts); - P.unchecked_broadcast(bundle); vector> others_ct; - for (auto& os : bundle) + + if (T::local_mul or OnlineOptions::singleton.direct) + { + Bundle bundle(P); + bundle.mine.store(diag.ciphertexts); + P.unchecked_broadcast(bundle); + for (auto& os : bundle) + { + others_ct.push_back({}); + os.get(others_ct.back(), Ciphertext(pk)); + } + } + else { - others_ct.push_back({}); - os.get(others_ct.back(), Ciphertext(pk)); + others_ct.push_back(diag.ciphertexts); + TreeSum().run(others_ct[0], P); } for (int j = 0; j < n_cols; j++) for (auto m : multipliers) { -#ifdef VERBOSE +#ifdef VERBOSE_HE fprintf(stderr, "column %d with party offset %d at %f\n", j, m->get_offset(), timer.elapsed()); fflush(stderr); #endif Ciphertext C(pk); - auto& multiplicands = others_ct[P.get_player(-m->get_offset())]; + auto& multiplicands = m->get_multiplicands(others_ct, pk); if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) { auto& queues = BaseMachine::s().queues; @@ -160,7 +171,7 @@ void HemiMatrixPrep::buffer_triples() CipherPlainMultJob job(products, multiplicands, multiplicands2, true); int start = queues.distribute(job, n_inner); #ifdef VERBOSE_HE - fprintf(stderr, "from %d in central thread\n", start); + fprintf(stderr, "from %d in central thread at %f\n", start, timer.elapsed()); fflush(stderr); #endif for (int i = start; i < n_inner; i++) @@ -185,7 +196,10 @@ void HemiMatrixPrep::buffer_triples() m->add(products[j], C, BOTH, n_inner); } - C += diag.dediag(products, n_matrices); + if (T::local_mul) + C += diag.dediag(products, n_matrices); + else + C = diag.dediag(products, n_matrices); for (int i = 0; i < n_matrices; i++) if (swapped) diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index c43b43e95..b2b510aa0 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -34,6 +34,9 @@ class HemiPrep : public SemiHonestRingPrep static void basic_setup(Player& P); static void teardown(); + static const FHE_PK& get_pk(); + static const FD& get_FTD(); + HemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index 6cdd75476..c456424e5 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -34,6 +34,20 @@ void HemiPrep::basic_setup(Player& P) T::clear::template init(); } +template +const FHE_PK& HemiPrep::get_pk() +{ + assert(pairwise_machine); + return pairwise_machine->pk; +} + +template +const typename T::clear::FD& HemiPrep::get_FTD() +{ + assert(pairwise_machine); + return pairwise_machine->setup().FieldD; +} + template HemiPrep::~HemiPrep() diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index d299fb18f..4a85cbe34 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -27,6 +27,7 @@ class HemiShare : public SemiShare typedef HemiPrep LivePrep; static const bool needs_ot = false; + static const bool local_mul = true; static true_type triple_matmul; HemiShare() diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index 9ff92fb0e..be0fac61d 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -140,12 +140,12 @@ void KeyGenProtocol::output_to(int player, vector& opened, vector& shares) { PrivateOutput po(*proc); - vector masked; for (auto& share : shares) - masked.push_back(po.start(player, share)); - MC->POpen(opened, masked, P); + po.prepare_sending(share, player); + po.exchange(); + opened.resize(shares.size()); for (auto& x : opened) - x = po.stop(player, x); + x = po.finalize(player); } template diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 571f391ef..2250417d0 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -52,6 +52,7 @@ class TreeSum virtual ~TreeSum(); void run(vector& values, const Player& P); + T run(const T& value, const Player& P); octetStream& get_buffer() { return os; } @@ -210,6 +211,14 @@ void TreeSum::run(vector& values, const Player& P) finish(values, P); } +template +T TreeSum::run(const T& value, const Player& P) +{ + vector values = {value}; + run(values, P); + return values[0]; +} + template size_t TreeSum::report_size(ReportType type) { @@ -244,14 +253,6 @@ void add_openings(vector& values, const Player& P, int sum_players, int last_ MC.player_timers[sender].start(); P.wait_receive(sender, oss[j]); MC.player_timers[sender].stop(); - if ((unsigned)oss[j].get_length() < values.size() * T::size()) - { - stringstream ss; - ss << "Not enough information received, expected " - << values.size() * T::size() << " bytes, got " - << oss[j].get_length(); - throw Processor_Error(ss.str()); - } MC.timers[SUM].start(); for (unsigned int i=0; i::Check(const Player& P) auto& vals = this->vals; auto& macs = this->macs; auto& popen_cnt = this->popen_cnt; + assert(int(macs.size()) <= popen_cnt); if (popen_cnt < 10) { diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index c7d477ad4..5a60281c6 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -12,6 +12,8 @@ using namespace std; #include "Networking/Player.h" #include "Tools/PointerVector.h" +template class Preprocessing; + /** * Abstract base class for opening protocols */ @@ -61,6 +63,8 @@ class MAC_Check_Base virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); virtual const Player& get_check_player(const Player& P) const { return P; } + + virtual void set_prep(Preprocessing&) {} }; #endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */ diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index 63bfe63ac..ff33a6eea 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -17,6 +17,7 @@ class MalRepRingShare : public MaliciousRep3Share> { typedef SignedZ2 T; typedef MaliciousRep3Share super; + typedef MalRepRingShare This; public: const static int BIT_LENGTH = K; @@ -26,7 +27,8 @@ class MalRepRingShare : public MaliciousRep3Share> typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index f98e9797f..e6f3a8a6a 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -13,6 +13,7 @@ template class Beaver; template class MaliciousRepPrepWithBits; template class MaliciousRepPO; template class MaliciousRepPrep; +template class SpecificPrivateOutput; namespace GC { @@ -30,8 +31,8 @@ class MaliciousRep3Share : public Rep3Share typedef HashMaliciousRepMC> MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput> Input; - typedef ::PrivateOutput> PrivateOutput; typedef MaliciousRepPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef Rep3Share Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MaliciousShamirPO.h b/Protocols/MaliciousShamirPO.h index 65003d108..5bffe4f8e 100644 --- a/Protocols/MaliciousShamirPO.h +++ b/Protocols/MaliciousShamirPO.h @@ -9,13 +9,14 @@ template class MaliciousShamirPO { +protected: Player& P; octetStream to_send; vector to_receive; vector shares; - MaliciousShamirMC MC; + typename T::Direct_MC MC; public: MaliciousShamirPO(Player& P); diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index 47592981f..fee8e8292 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -13,6 +13,7 @@ template class MaliciousRepPrepWithBits; template class MaliciousRepPrep; template class MaliciousShamirPO; +template class SpecificPrivateOutput; namespace GC { @@ -23,14 +24,15 @@ template class MaliciousShamirShare : public ShamirShare { typedef ShamirShare super; + typedef MaliciousShamirShare This; public: typedef Beaver> Protocol; typedef MaliciousShamirMC MAC_Check; typedef MAC_Check Direct_MC; typedef ShamirInput Input; - typedef ::PrivateOutput PrivateOutput; typedef MaliciousShamirPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ShamirShare Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index fa3bc9f03..c90a5e277 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -76,12 +76,6 @@ class MamaShare : public Share_, MamaMac> return string(1, T::type_char()); } - static void read_or_generate_mac_key(string, Player&, mac_key_type& key) - { - SeededPRNG G; - key.randomize(G); - } - MamaShare() { } diff --git a/Protocols/PostSacriRepFieldShare.h b/Protocols/PostSacriRepFieldShare.h index a7fed8afb..06196762b 100644 --- a/Protocols/PostSacriRepFieldShare.h +++ b/Protocols/PostSacriRepFieldShare.h @@ -15,6 +15,7 @@ template class PostSacriRepFieldShare : public MaliciousRep3Share { typedef MaliciousRep3Share super; + typedef PostSacriRepFieldShare This; public: typedef typename super::clear clear; @@ -23,7 +24,8 @@ class PostSacriRepFieldShare : public MaliciousRep3Share typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MaliciousRepPrepWithBits LivePrep; PostSacriRepFieldShare() diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index d4f2ab0fd..7cbd483c4 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -17,6 +17,7 @@ template class PostSacriRepRingShare : public Rep3Share2 { typedef Rep3Share2 super; + typedef PostSacriRepRingShare This; public: static const int BIT_LENGTH = K; @@ -33,7 +34,8 @@ class PostSacriRepRingShare : public Rep3Share2 typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h index e6a8eb525..09be88cb1 100644 --- a/Protocols/ProtocolSet.h +++ b/Protocols/ProtocolSet.h @@ -42,8 +42,13 @@ class ProtocolSet { } - ~ProtocolSet() + /** + * Run all protocol checks + */ + void check() { + protocol.check(); + output.Check(processor.P); } }; @@ -73,6 +78,15 @@ class BinaryProtocolSet *thread.protocol), input(output, prep, P) { } + + /** + * Run all protocol checks + */ + void check() + { + protocol.check(); + output.Check(protocol.P); + } }; /** @@ -102,6 +116,15 @@ class MixedProtocolSet arithmetic.protocol), input(arithmetic.input) { } + + /** + * Run all protocol checks + */ + void check() + { + arithmetic.check(); + binary.check(); + } }; #endif /* PROTOCOLS_PROTOCOLSET_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index e85065ac0..44853b79a 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -15,7 +15,8 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; -template class PrivateOutput; +template class ReplicatedPO; +template class SpecificPrivateOutput; template class RepShare : public FixedVec, public ShareInterface @@ -99,6 +100,7 @@ template class Rep3Share : public RepShare { typedef RepShare super; + typedef Rep3Share This; public: typedef T clear; @@ -107,7 +109,8 @@ class Rep3Share : public RepShare typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 23f28cf9b..e52d160bb 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -24,7 +24,8 @@ class Rep3Share2 : public Rep3Share> typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep2k LivePrep; typedef Rep3Share2 Honest; typedef SignedZ2 clear; diff --git a/Protocols/Rep4Input.h b/Protocols/Rep4Input.h index f1bc29af9..04acd0043 100644 --- a/Protocols/Rep4Input.h +++ b/Protocols/Rep4Input.h @@ -31,7 +31,6 @@ class Rep4Input : public InputBase void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); - void send_mine(); void exchange(); T finalize_mine(); diff --git a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp index 5600b45c7..48844396b 100644 --- a/Protocols/Rep4Input.hpp +++ b/Protocols/Rep4Input.hpp @@ -64,12 +64,6 @@ void Rep4Input::add_other(int player, int) results[player].push_back(res); } -template -void Rep4Input::send_mine() -{ - throw not_implemented(); -} - template void Rep4Input::exchange() { diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 67527a208..2357d0f5e 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -19,10 +19,6 @@ using namespace std; template class SubProcessor; template class ReplicatedMC; template class ReplicatedInput; -template class ReplicatedPrivateOutput; -template class Share; -template class Rep3Share; -template class MAC_Check_Base; template class Preprocessing; class Instruction; @@ -141,9 +137,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase void trunc_pr(const vector& regs, int size, U& proc, false_type); public: - typedef ReplicatedMC MAC_Check; - typedef ReplicatedInput Input; - static const bool uses_triples = false; Replicated(Player& P); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 374ed89b1..1a8a66b99 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -10,6 +10,7 @@ #include "Processor/Processor.h" #include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" #include "ReplicatedInput.h" #include "Rep3Share2k.h" @@ -162,14 +163,13 @@ void Replicated::prepare_mul(const T& x, } template -inline void Replicated::prepare_reshare(const typename T::clear& share, +void Replicated::prepare_reshare(const typename T::clear& share, int n) { - auto add_share = share; typename T::value_type tmp[2]; for (int i = 0; i < 2; i++) tmp[i].randomize(shared_prngs[i], n); - add_share += tmp[0] - tmp[1]; + auto add_share = share + tmp[0] - tmp[1]; add_share.pack(os[0], n); add_shares.push_back(add_share); } diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 916ee6b8f..b12f7f91f 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -56,16 +56,24 @@ BufferPrep::~BufferPrep() << " bit generation" << endl; #endif + auto field_type = T::clear::field_type(); + auto& my_usage = this->usage.files.at(field_type); + this->print_left("triples", triples.size() * T::default_length, type_string, this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) * T::default_length); + size_t used_bits = my_usage.at(DATA_BIT); + if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) + // add dabits with computation modulo power of two but without MAC + used_bits += my_usage.at(DATA_DABIT); + this->print_left("bits", bits.size(), type_string, used_bits); + #define X(KIND, TYPE) \ this->print_left(#KIND, KIND.size(), type_string, \ this->usage.files.at(T::clear::field_type()).at(TYPE)); X(squares, DATA_SQUARE) X(inverses, DATA_INVERSE) - X(bits, DATA_BIT) X(dabits, DATA_DABIT) #undef X @@ -601,17 +609,6 @@ void buffer_bits_from_players(vector>& player_bits, for (int i = 0; i < n_relevant_players; i++) for (auto& x : player_bits[i]) x = input.finalize((base_player + i) % P.num_players(), n_bits); -#if !defined(__clang__) && (__GNUC__ == 6) - // mitigate compiler bug - Bundle bundle(P); - P.unchecked_broadcast(bundle); -#endif -#ifdef DEBUG_BIT_SACRIFICE - typename T::MAC_Check MC; - for (int i = 0; i < n_relevant_players; i++) - for (auto& x : player_bits[i]) - assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1)); -#endif } template @@ -1164,18 +1161,18 @@ void BufferPrep::buffer_inputs_as_usual(int player, SubProcessor* proc) typename T::clear r; r.randomize(G); input.add_mine(r); - this->inputs[player].push_back({input.finalize_mine(), r}); + this->inputs[player].push_back({input.finalize(player), r}); } - input.send_mine(); + input.exchange(); } else { - octetStream os; - P.receive_player(player, os); - T share; + for (int i = 0; i < buffer_size; i++) + input.add_other(player); + input.exchange(); for (int i = 0; i < buffer_size; i++) { - input.finalize_other(player, share, os); + auto share = input.finalize(player); this->inputs[player].push_back({share, 0}); } } diff --git a/Protocols/ReplicatedPrivateOutput.h b/Protocols/ReplicatedPrivateOutput.h deleted file mode 100644 index b9e546ca2..000000000 --- a/Protocols/ReplicatedPrivateOutput.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * ReplicatedPrivateOutput.h - * - */ - -#ifndef PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ -#define PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ - -template -class SubProcessor; -template -class Share; - -template -class ReplicatedPrivateOutput -{ - SubProcessor& proc; - -public: - ReplicatedPrivateOutput(SubProcessor& proc); - - void start(int player, int target, int source); - void stop(int player, int source); -}; - -#endif /* PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ */ diff --git a/Protocols/ReplicatedPrivateOutput.hpp b/Protocols/ReplicatedPrivateOutput.hpp deleted file mode 100644 index d34872235..000000000 --- a/Protocols/ReplicatedPrivateOutput.hpp +++ /dev/null @@ -1,30 +0,0 @@ -/* - * ReplicatedPrivateOutput.cpp - * - */ - -#include "ReplicatedPrivateOutput.h" -#include "Processor/Processor.h" -#include "Math/FixedVec.h" -#include "Math/Integer.h" - -template -inline ReplicatedPrivateOutput::ReplicatedPrivateOutput( - SubProcessor& proc) : - proc(proc) -{ -} - -template -void ReplicatedPrivateOutput::start(int player, int target, - int source) -{ - (void)player, (void)target, (void)source; - throw runtime_error("not implemented, use PrivateOutput"); -} - -template -void ReplicatedPrivateOutput::stop(int player, int source) -{ - (void)player, (void)source; -} diff --git a/Protocols/Semi.h b/Protocols/Semi.h index e290ca0eb..5f63a9d62 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -71,6 +71,12 @@ class Semi : public SPDZ proc.get_S()[info.source_base + i] >> info.m; } } + + void buffer_random() + { + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->random.push_back(G.get()); + } }; #endif /* PROTOCOLS_SEMI_H_ */ diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 87a1e08e5..4fc265b7c 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -14,34 +14,33 @@ template class SemiMC; * Additive secret sharing input protocol */ template -class SemiInput : public IndividualInput +class SemiInput : public InputBase { - SeededPRNG secure_prng; + vector send_prngs; + vector recv_prngs; + Player& P; + vector> shares; public: - SemiInput(SubProcessor& proc, SemiMC& MC) : - IndividualInput(proc) + SemiInput(SubProcessor& proc, SemiMC&) : + SemiInput(&proc, proc.P) { - (void) MC; } - SemiInput(SubProcessor* proc, Player& P) : - IndividualInput(proc, P) - { - } + SemiInput(SubProcessor* proc, Player& P); SemiInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : - SemiInput(P) + SemiInput(0, P) { (void) MC, (void) prep; } - SemiInput(Player& P) : - IndividualInput(0, P) - { - } - + void reset(int player); void add_mine(const typename T::clear& input, int n_bits = -1); + void add_other(int player, int n_bits = -1); + void exchange(); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + T finalize_mine(); }; #endif /* PROTOCOLS_SEMIINPUT_H_ */ diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 28673250f..3ed1feefe 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -11,22 +11,64 @@ #include "ShamirInput.hpp" template -void SemiInput::add_mine(const typename T::clear& input, int n_bits) +SemiInput::SemiInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P) +{ + shares.resize(P.num_players()); + vector to_send(P.num_players()), to_receive; + for (int i = 0; i < P.num_players(); i++) + { + send_prngs.push_back({}); + to_send[i].append(send_prngs.back().get_seed(), SEED_SIZE); + } + P.send_receive_all(to_send, to_receive); + recv_prngs.resize(P.num_players()); + for (int i = 0; i < P.num_players(); i++) + if (i != P.my_num()) + recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE)); + this->reset_all(P); +} + +template +void SemiInput::reset(int player) +{ + shares[player].clear(); +} + +template +void SemiInput::add_mine(const typename T::clear& input, int) { auto& P = this->P; typename T::open_type sum, share; for (int i = 0; i < P.num_players(); i++) { - if (i < P.num_players() - 1) - share.randomize(secure_prng, n_bits); - else - share = input - sum; - sum += share; - if (i == P.my_num()) - this->shares.push_back(share); - else - share.pack(this->os[i], n_bits); + if (i != P.my_num()) + sum += send_prngs[i].template get(); } + shares[P.my_num()].push_back(input - sum); +} + +template +void SemiInput::add_other(int, int) +{ +} + +template +void SemiInput::exchange() +{ +} + +template +void SemiInput::finalize_other(int player, T& target, octetStream&, + int) +{ + target = recv_prngs[player].template get(); +} + +template +T SemiInput::finalize_mine() +{ + return shares[P.my_num()].next(); } #endif diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index f722886eb..402173e98 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -27,7 +27,6 @@ class Shamir : public ProtocolBase { typedef typename T::open_type::Scalar U; - octetStreams os; vector reconstruction; U rec_factor; ShamirInput* resharing; diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 9fe10bdea..8bfdf70ea 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -69,8 +69,6 @@ int Shamir::get_n_relevant_players() template void Shamir::reset() { - os.reset(P); - if (resharing == 0) { resharing = new ShamirInput(0, P); @@ -78,6 +76,9 @@ void Shamir::reset() for (int i = 0; i < P.num_players(); i++) resharing->reset(i); + + for (int i = 0; i < n_mul_players; i++) + resharing->add_sender(i); } template @@ -92,37 +93,27 @@ template void Shamir::prepare_mul(const T& x, const T& y, int n) { (void) n; - auto add_share = x * y * rec_factor; if (P.my_num() < n_mul_players) - resharing->add_mine(add_share); + resharing->add_mine(x * y * rec_factor); } template void Shamir::exchange() { - vector senders(P.num_players(), false); - for (int i = 0; i < n_mul_players; i++) - senders[i] = true; - P.send_receive_all(senders, resharing->os, os); + assert(resharing); + resharing->exchange(); } template void Shamir::start_exchange() { - if (P.my_num() < n_mul_players) - for (int offset = 1; offset < P.num_players(); offset++) - P.send_relative(offset, resharing->os[P.get_player(offset)]); + resharing->start_exchange(); } template void Shamir::stop_exchange() { - for (int offset = 1; offset < P.num_players(); offset++) - { - int receive_from = P.get_player(-offset); - if (receive_from < n_mul_players) - P.receive_player(receive_from, os[receive_from]); - } + resharing->stop_exchange(); } template @@ -136,15 +127,8 @@ template T Shamir::finalize(int n_relevant_players) { ShamirShare res = U(0); - if (P.my_num() < n_relevant_players) - res = resharing->finalize_mine(); for (int i = 0; i < n_relevant_players; i++) - if (i != P.my_num()) - { - T tmp; - resharing->finalize_other(i, tmp, os[i]); - res += tmp; - } + res += resharing->finalize(i); return res; } @@ -259,7 +243,7 @@ vector Shamir::get_randoms(PRNG& G, int t) input.reset_all(P); int buffer_size = OnlineOptions::singleton.batch_size; for (int i = 0; i < buffer_size; i += hyper.size()) - input.add_mine(G.get()); + input.add_from_all(G.get()); input.exchange(); vector inputs; vector random; diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 023467077..91e093091 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -21,10 +21,11 @@ class IndividualInput : public PrepLessInput protected: Player& P; octetStreams os; + vector senders; public: IndividualInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), P(P) + PrepLessInput(proc), P(P), senders(P.num_players()) { this->reset_all(P); } @@ -34,10 +35,14 @@ class IndividualInput : public PrepLessInput } void reset(int player); + void add_sender(int player); void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + + void start_exchange(); + void stop_exchange(); }; /** diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index d84b09a6b..6d9992ad7 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -20,6 +20,8 @@ void IndividualInput::reset(int player) this->i_share = 0; os.reset(P); } + + senders[player] = false; } template @@ -68,12 +70,20 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) else x.pack(this->os[i]); } + + this->senders[P.my_num()] = true; +} + +template +void IndividualInput::add_sender(int player) +{ + senders[player] = true; } template void IndividualInput::add_other(int player, int) { - (void) player; + add_sender(player); } template @@ -87,7 +97,26 @@ void IndividualInput::send_mine() template void IndividualInput::exchange() { - P.send_receive_all(os, InputBase::os); + P.send_receive_all(senders, os, InputBase::os); +} + +template +void IndividualInput::start_exchange() +{ + if (senders[P.my_num()]) + for (int offset = 1; offset < P.num_players(); offset++) + P.send_relative(offset, os[P.get_player(offset)]); +} + +template +void IndividualInput::stop_exchange() +{ + for (int offset = 1; offset < P.num_players(); offset++) + { + int receive_from = P.get_player(-offset); + if (senders[receive_from]) + P.receive_player(receive_from, InputBase::os[receive_from]); + } } template diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 8f76d6a79..6bda92dfc 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -33,9 +33,12 @@ class IndirectShamirMC : public MAC_Check_Base template class ShamirMC : public IndirectShamirMC { + typedef typename T::open_type open_type; typedef typename T::open_type::Scalar rec_type; vector reconstruction; + ShamirMC(const ShamirMC&); + void finalize(vector& values, const vector& S); protected: @@ -71,6 +74,7 @@ class ShamirMC : public IndirectShamirMC void Check(const Player& P) { (void)P; } vector get_reconstruction(const Player& P); + open_type reconstruct(const vector& shares); }; #endif /* PROTOCOLS_SHAMIRMC_H_ */ diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 6d6af9136..e3e7cd3ac 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -130,6 +130,19 @@ typename T::open_type ShamirMC::finalize_open() return res; } +template +typename T::open_type ShamirMC::reconstruct(const vector& shares) +{ + assert(reconstruction.size()); + typename T::open_type res; + for (size_t j = 0; j < reconstruction.size(); j++) + { + res += shares[j] * reconstruction[j]; + } + + return res; +} + template void IndirectShamirMC::exchange(const Player& P) { diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 6e818c39f..e7daabfcf 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -13,6 +13,8 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; +template class MaliciousShamirPO; +template class SpecificPrivateOutput; namespace GC { @@ -22,6 +24,8 @@ template class CcdSecret; template class ShamirShare : public T, public ShareInterface { + typedef ShamirShare This; + public: typedef T clear; typedef T open_type; @@ -34,7 +38,8 @@ class ShamirShare : public T, public ShareInterface typedef IndirectShamirMC MAC_Check; typedef ShamirMC Direct_MC; typedef ShamirInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef MaliciousShamirPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef ShamirShare Honest; diff --git a/Protocols/Share.h b/Protocols/Share.h index 743a2c614..92be4f144 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -55,6 +55,7 @@ class Share_ : public ShareInterface const static bool needs_ot = T::needs_ot; const static bool dishonest_majority = T::dishonest_majority; const static bool variable_players = T::variable_players; + const static bool has_mac = true; static int size() { return T::size() + V::size(); } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index ae6e7b7dd..444214e47 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -34,6 +34,7 @@ class ShareInterface static const bool has_trunc_pr = false; static const bool has_split = false; + static const bool has_mac = false; static const false_type triple_matmul; diff --git a/Protocols/SpdzWiseInput.h b/Protocols/SpdzWiseInput.h index e9597527d..4c5675e91 100644 --- a/Protocols/SpdzWiseInput.h +++ b/Protocols/SpdzWiseInput.h @@ -36,11 +36,8 @@ class SpdzWiseInput : public InputBase void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); - void send_mine(); void exchange(); T finalize(int player, int n_bits = -1); - T finalize_mine(); - void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; #endif /* PROTOCOLS_SPDZWISEINPUT_H_ */ diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index e0d508e51..7aaa14c92 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -85,21 +85,3 @@ T SpdzWiseInput::finalize(int player, int) { return shares[player].next(); } - -template -void SpdzWiseInput::send_mine() -{ - throw runtime_error("use exchange()"); -} - -template -T SpdzWiseInput::finalize_mine() -{ - throw runtime_error("use finalize()"); -} - -template -void SpdzWiseInput::finalize_other(int, T&, octetStream&, int) -{ - throw runtime_error("use finalize()"); -} diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9e953e730..9991dafb2 100644 --- a/Protocols/SpdzWiseMC.h +++ b/Protocols/SpdzWiseMC.h @@ -32,7 +32,7 @@ class SpdzWiseMC : public MAC_Check_Base { } - void init_open(const Player& P, int n) + void init_open(const Player& P, int n = 0) { inner_MC.init_open(P, n); } diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index 9cb86017a..1090fc08e 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -15,7 +15,6 @@ #include "Spdz2kPrep.hpp" #include "ShamirMC.hpp" #include "MaliciousRepPO.hpp" -#include "MaliciousShamirPO.hpp" #include "GC/RepPrep.hpp" template diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h new file mode 100644 index 000000000..de7406bba --- /dev/null +++ b/Protocols/TemiPrep.h @@ -0,0 +1,72 @@ +/* + * TemiPrep.h + * + */ + +#ifndef PROTOCOLS_TEMIPREP_H_ +#define PROTOCOLS_TEMIPREP_H_ + +#include "ReplicatedPrep.h" +#include "FHEOffline/TemiSetup.h" + +template class HemiMatrixPrep; + +template +class TemiMultiplier +{ + typedef typename T::clear::FD FD; + + vector multiplicands; + + Player& P; + +public: + TemiMultiplier(Player& P); + + vector& get_multiplicands( + vector>& ciphertexts, const FHE_PK& pk); + void add(Plaintext_& res, const Ciphertext& C, OT_ROLE role = BOTH, + int n_summands = 1); + + int get_offset() + { + return 0; + } +}; + +/** + * Semi-honest triple generation with semi-homomorphic encryption + */ +template +class TemiPrep : public SemiHonestRingPrep +{ + friend class HemiMatrixPrep; + + typedef typename T::clear::FD FD; + + static Lock lock; + static TemiSetup* setup; + + vector*> multipliers; + +public: + static void basic_setup(Player& P); + static void teardown(); + + static const FD& get_FTD(); + static const FHE_PK& get_pk(); + static const TemiSetup& get_setup(); + + TemiPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), + BitPrep(proc, usage), RingPrep(proc, usage), + SemiHonestRingPrep(proc, usage) + { + } + + void buffer_triples(); + + vector*>& get_multipliers(); +}; + +#endif /* PROTOCOLS_TEMIPREP_H_ */ diff --git a/Protocols/TemiPrep.hpp b/Protocols/TemiPrep.hpp new file mode 100644 index 000000000..1088a99cc --- /dev/null +++ b/Protocols/TemiPrep.hpp @@ -0,0 +1,129 @@ +/* + * TemiPrep.hppg + * + * + */ + +#ifndef PROTOCOLS_TEMIPREP_HPP_ +#define PROTOCOLS_TEMIPREP_HPP_ + +#include "TemiPrep.h" +#include "FHEOffline/SimpleMachine.h" + +#include "FHEOffline/DataSetup.hpp" + +template +TemiSetup* TemiPrep::setup; + +template +Lock TemiPrep::lock; + +template +void TemiPrep::basic_setup(Player& P) +{ + assert(not setup); + setup = new TemiSetup; + MachineBase machine; + setup->secure_init(P, T::clear::length()); + read_or_generate_secrets(*setup, P, machine, 1, true_type()); + T::clear::template init(); +} + +template +void TemiPrep::teardown() +{ + if (setup) + delete setup; +} + +template +const typename T::clear::FD& TemiPrep::get_FTD() +{ + assert(setup); + return setup->FieldD; +} + +template +inline const FHE_PK& TemiPrep::get_pk() +{ + assert(setup); + return setup->pk; +} + +template +const TemiSetup& TemiPrep::get_setup() +{ + assert(setup); + return *setup; +} + +template +void TemiPrep::buffer_triples() +{ + lock.lock(); + if (setup == 0) + { + PlainPlayer P(this->proc->P.N, "Temi" + T::type_string()); + basic_setup(P); + } + lock.unlock(); + + auto& P = this->proc->P; + auto& FieldD = setup->FieldD; + + Plaintext_ a(FieldD), b(FieldD), c(FieldD); + + SeededPRNG G; + a.randomize(G); + b.randomize(G); + + TreeSum ts; + auto C = ts.run(setup->pk.encrypt(a), P); + C = ts.run(C * b + setup->pk.template encrypt(FieldD), P); + c = SimpleDistDecrypt(P, *setup).reshare(C); + + for (unsigned i = 0; i < a.num_slots(); i++) + this->triples.push_back({{a.element(i), b.element(i), c.element(i)}}); +} + +template +vector*>& TemiPrep::get_multipliers() +{ + assert(setup); + assert( + OnlineOptions::singleton.batch_size + <= setup->params.get_matrix_dim()); + assert(this->proc); + if (multipliers.empty()) + multipliers.push_back(new TemiMultiplier(this->proc->P)); + return multipliers; +} + +template +TemiMultiplier::TemiMultiplier(Player& P) : P(P) +{ +} + +template +vector& TemiMultiplier::get_multiplicands( + vector >& ciphertexts, const FHE_PK& pk) +{ + multiplicands.clear(); + multiplicands.resize(ciphertexts[0].size(), pk); + for (size_t j = 0; j < multiplicands.size(); j++) + for (size_t i = 0; i < ciphertexts.size(); i++) + multiplicands[j] += ciphertexts[i].at(j); + return multiplicands; +} + +template +void TemiMultiplier::add(Plaintext_& res, const Ciphertext& C, + OT_ROLE, int) +{ + TreeSum ts; + SimpleDistDecrypt dd(P, TemiPrep::get_setup()); + auto zero = TemiPrep::get_pk().template encrypt(TemiPrep::get_FTD()); + res += dd.reshare(ts.run(C + zero, P)); +} + +#endif /* PROTOCOLS_TEMIPREP_HPP_ */ diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h new file mode 100644 index 000000000..f4f37dcd6 --- /dev/null +++ b/Protocols/TemiShare.h @@ -0,0 +1,42 @@ +/* + * TemiShare.h + * + */ + +#ifndef PROTOCOLS_TEMISHARE_H_ +#define PROTOCOLS_TEMISHARE_H_ + +#include "HemiShare.h" + +template class TemiPrep; +template class Hemi; + +template +class TemiShare : public HemiShare +{ + typedef TemiShare This; + typedef HemiShare super; + +public: + typedef SemiMC MAC_Check; + typedef DirectSemiMC Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef typename conditional, Beaver>::type Protocol; + typedef TemiPrep LivePrep; + + static const bool needs_ot = false; + static const bool local_mul = false; + + TemiShare() + { + } + template + TemiShare(const U& other) : + super(other) + { + } + +}; + +#endif /* PROTOCOLS_TEMISHARE_H_ */ diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 951cbfe74..45d92613f 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -317,7 +317,14 @@ void read_mac_key(const string& directory, int player_num, int nplayers, U& key) throw mac_key_error(filename); } - key.input(inpf,true); + try + { + key.input(inpf,true); + } + catch(exception&) + { + throw mac_key_error(filename); + } if (inpf.fail()) throw mac_key_error(filename); diff --git a/README.md b/README.md index bd1075121..99d0f0763 100644 --- a/README.md +++ b/README.md @@ -85,10 +85,31 @@ The following table lists all protocols that are fully supported. | --- | --- | --- | --- | --- | | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | -| Semi-honest, dishonest majority | [Semi / Hemi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +Modulo prime and modulo 2^k are the two settings that allow +integer-like computation. For k = 64, the latter corresponds to the +computation available on the widely used 64-bit processors. GF(2^n) +denotes Galois extension fields of order 2^n, which are different to +computation modulo 2^n. In particular, every element has an inverse, +which is not the case modulo 2^n. See [this +article](https://en.wikipedia.org/wiki/Finite_field) for an +introduction. Modulo prime and GF(2^n) are lumped together because the +protocols are very similar due to the mathematical properties. + +Bin. SS stands for binary secret sharing, that is secret sharing +modulo two. In some settings, this requires specific protocols as some +protocols require the domain size to be larger than two. In other +settings, the protocol is the same mathematically speaking, but a +specific implementation allows for optimizations such as using the +inherent parallelism of bit-wise operations on machine words. + +A security model specifies how many parties are "allowed" to misbehave +in what sense. Malicious means that not following the protocol will at +least be detected while semi-honest means that even corrupted parties +are assumed to follow the protocol. See [this paper](https://eprint.iacr.org/2020/300) for an explanation of the various security models and a high-level introduction to multi-party computation. @@ -257,7 +278,9 @@ compute the preprocessing time for a particular computation. add `AVX_OT = 0` in addition. - For optimal results on Linux on ARM, add `ARCH = -march=-march=armv8.2-a+crypto` to `CONFIG.mine`. This enables the - hardware support for AES. + hardware support for AES. See the [GCC + documentation](https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#AArch64-Options) + on available options. - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to a local, unversioned directory to store preprocessing data (the default is `Player-Data` in the current directory). - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. @@ -501,6 +524,7 @@ The following table shows all programs for dishonest-majority computation using | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | | `chaigear-party.x` | Adapted [HighGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `chaigear.sh` | | `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | +| `temi-party.x` | Adapted [CDN01](https://eprint.iacr.org/2000/055) | Mod prime | Semi-honest | `temi.sh` | | `soho-party.x` | Somewhat homomorphic encryption | Mod prime | Semi-honest | `soho.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | @@ -538,6 +562,11 @@ Hemi and Soho denote the stripped version version of LowGear and HighGear, respectively, for semi-honest security similar to Semi, that is, generating additively shared Beaver triples using semi-homomorphic encryption. +Temi in turn denotes the adaption of +[Cramer et al.](https://eprint.iacr.org/2000/055) to LWE-based +semi-homomorphic encryption. +Both Hemi and Temi use the diagonal packing by [Halevi and +Shoup](https://eprint.iacr.org/2014/106) for matrix multiplication. We will use MASCOT to demonstrate the use, but the other protocols work similarly. diff --git a/Scripts/prep-usage.py b/Scripts/prep-usage.py new file mode 100755 index 000000000..cb8ca6198 --- /dev/null +++ b/Scripts/prep-usage.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +import sys, os +import collections + +sys.path.append('.') + +from Compiler.program import * +from Compiler.instructions_base import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +res = collections.defaultdict(lambda: 0) +m = 0 + +tapename = next(Program.read_tapes(sys.argv[1])) +res = Tape.ReqNum() +for inst in Tape.read_instructions(tapename): + res.update(inst.get_usage()) + +for x in res.pretty(): + print(x) diff --git a/Scripts/temi.sh b/Scripts/temi.sh new file mode 100755 index 000000000..86f46c548 --- /dev/null +++ b/Scripts/temi.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player temi-party.x $* || exit 1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 10fe575f2..e8c02f6cb 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -59,7 +59,7 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py $compile_opts tutorial for i in rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ - atlas mal-shamir sy-shamir hemi semi \ + atlas mal-shamir sy-shamir hemi semi temi \ soho mascot; do test_vm $i $run_opts done diff --git a/Tools/Buffer.h b/Tools/Buffer.h index 941ec4256..ffd411233 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -86,6 +86,10 @@ octetStream check_file_signature(ifstream& file, const string& filename) { throw signature_mismatch(filename); } + catch (IO_Error&) + { + throw signature_mismatch(filename); + } if (file_signature() != file_spec) throw signature_mismatch(filename); return file_spec; diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index 96f69b0c5..f6f4ba2ec 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -35,8 +35,8 @@ wrong_gfp_size::wrong_gfp_size(const char* name, const bigint& p, { } -overflow::overflow(const char* name, size_t i, size_t n) : - runtime_error(string(name) + " overflow: " + to_string(i) + "/" + to_string(n)) +overflow::overflow(const string& name, size_t i, size_t n) : + runtime_error(name + " overflow: " + to_string(i) + "/" + to_string(n)) { } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index 18406cf6c..fff8b2de4 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -237,7 +237,7 @@ class mac_key_error: public runtime_error class overflow : public runtime_error { public: - overflow(const char* name, size_t i, size_t n); + overflow(const string& name, size_t i, size_t n); }; class unknown_input_type : public runtime_error diff --git a/Tools/octetStream.h b/Tools/octetStream.h index cd90b0e94..676382eaf 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -80,6 +80,8 @@ class octetStream size_t get_ptr() const { return ptr; } /// Length size_t get_length() const { return len; } + /// Length including size tag + size_t get_total_length() const { return len + sizeof(len); } /// Allocation size_t get_max_length() const { return mxlen; } /// Data pointer diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index 45e5f3371..962b27753 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -129,12 +129,12 @@ void run(int argc, char** argv) output.prepare_open(c); } output.exchange(P); + set.check(); cout << "result: "; for (int i = 0; i < n; i++) cout << output.finalize_open() << " "; cout << endl; - protocol.check(); - output.Check(P); + set.check(); } diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp index 532d705e4..a36949d6f 100644 --- a/Utils/mixed-example.cpp +++ b/Utils/mixed-example.cpp @@ -126,12 +126,12 @@ void run(char** argv) output.prepare_open(res); } output.exchange(P); - bit_output.Check(P); + set.check(); cout << "result: "; for (int i = 0; i < n; i++) cout << output.finalize_open() << " "; cout << endl; - output.Check(P); + set.check(); } diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 9cae6953f..83571c218 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -110,7 +110,7 @@ void run(char** argv, int prime_length) c = protocol.finalize_dotprod(n); // protocol check before revealing results - protocol.check(); + set.check(); output.init_open(P); output.prepare_open(c); @@ -120,5 +120,5 @@ void run(char** argv, int prime_length) cout << "result: " << result << endl; // result check after opening - output.Check(P); + set.check(); } diff --git a/doc/instructions.rst b/doc/instructions.rst index 1a833994e..fb62066ed 100644 --- a/doc/instructions.rst +++ b/doc/instructions.rst @@ -85,12 +85,10 @@ Compiler.instructions module .. automodule:: Compiler.instructions :members: :no-undoc-members: - :exclude-members: asm_input, inputmask, lts, print_char4_regint, - print_char_regint, protectmemc, sqrs, - start_grind, startprivateoutput, stop_grind, - stopprivateoutput, writesocketc, writesocketint, - protectmemint, protectmems, print_mem, - matmul_base, g2muls, inputmixed_base, raw_output + :exclude-members: asm_input, sqrs, + start_grind, stop_grind, + writesocketc, writesocketint, + matmul_base, inputmixed_base, raw_output Compiler.GC.instructions module ------------------------------- diff --git a/doc/low-level.rst b/doc/low-level.rst index c70bf5b65..7f5474fd4 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -309,6 +309,11 @@ Share Types - ``SpdzWiseShare`` - `SPDZ-wise `_. ``T`` must be ``MaliciousShamirShare`` or ``MaliciousRep3Share``. + * + - ``TemiShare`` + - Semi-honest protocol with Beaver multiplication based on + threshold semi-homomorphic encryption. ``T`` must be + ``gfp_`` or ``gf2n_short``. Protocol Setup diff --git a/doc/non-linear.rst b/doc/non-linear.rst index bcdbbd3ae..e5df4c204 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -88,7 +88,7 @@ The following table lists the matching arithmetic and binary protocols. cut-and-choose analysis by `Furukawa et al. `_ * - - Semi, Hemi, Soho, Semi2k + - Semi, Hemi, Temi, Soho, Semi2k - SemiBin (Beaver triples modulo 2 using OT) * - `Malicious Shamir `_ diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 1441e3524..21500c455 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -85,7 +85,37 @@ modulo the default 128-bit prime 00000025 -``Fake-Offline.x`` generates preprocessing data insecurely for a range -of protocols, and ``{mascot,cowgear,mal-shamir}-offline.x`` generate +The actual data is stored is by simple concatenation. For example, +triples are stored as repetitions of ``a, b, ab``, and daBits are +stored as repetitions of ``a, b`` where ``a`` is the arithmetic +share and ``b`` is the binary share. + +For protocols with MAC, the value share is stored before the MAC +share. + +Values are generally stored in little-endian order. Note the following +domain specifics: + +Modulo a prime + Values are stored in `Montgomery representation + `_ + with :math:`R` being the smallest power of :math:`2^{64}` larger than + the prime. For example, :math:`R = 2^{128}` for a 128-bit prime. + Furthermore, the values are stored in the smallest number of 8-byte + blocks necessary, all in little-endian order. + +Modulo a power of two: + Values are stored in the smallest number of 8-byte blocks necessary, + all in little-endian order. + +:math:`GF(2^n)` + Values are stored in blocks according to the storage size above, + all in little-endian order. + +For further details, have a look at ``Utils/Fake-Offline.cpp``, which +contains code that generates preprocessing data insecurely for a range +of protocols (underlying the binary ``Fake-Offline.x``). + +``{mascot,cowgear,mal-shamir}-offline.x`` generate sufficient preprocessing data for a specific high-level program with MASCOT, CowGear, and malicious Shamir secret sharing, respectively. diff --git a/doc/requirements.txt b/doc/requirements.txt index cd6467ed8..32add0c79 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1 +1,2 @@ breathe +sphinx-rtd-theme==0.5.2 From 08ea9b3bd0b33aa5331c9da7f3f5d788d5ffa19b Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 22 Feb 2022 13:25:12 +1100 Subject: [PATCH 030/221] Better scaling network setup. --- Networking/Player.cpp | 26 +++++++++----------------- Networking/Server.cpp | 13 +++++++------ Networking/Server.h | 2 +- Tools/octetStream.cpp | 15 +++++++++++++++ Tools/octetStream.h | 5 +++++ 5 files changed, 37 insertions(+), 24 deletions(-) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index cd92df541..b4bab177f 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -146,25 +146,17 @@ void Names::setup_names(const char *servername, int my_port) #endif // Now get the set of names - int i; - size_t tmp; - receive(socket_num,tmp,4); - nplayers = tmp; -#ifdef VERBOSE - cerr << nplayers << " players\n"; -#endif - names.resize(nplayers); - ports.resize(nplayers); - for (i=0; i void octetStream::exchange(T send_socket, T receive_socket, octetStream& receive_stream) const { diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 676382eaf..96af81917 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -207,6 +207,11 @@ class octetStream s.len=l; } + /// Append string + void store(const string& str); + /// Read string + void get(string& str); + /// Send on ``socket_num`` template void Send(T socket_num) const; From 9c3e607068084f8fff716c510a47ae51a854c3fe Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 24 Feb 2022 11:57:21 +1100 Subject: [PATCH 031/221] Bug when inputting to large arrays. --- Compiler/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 1dbe1f909..a81a2562d 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5409,7 +5409,7 @@ def input_from(self, player, budget=None, raw=False): input_from = self.value_type.get_input_from try: self.assign(input_from(player, size=len(self))) - except TypeError: + except (TypeError, CompilerError): @library.for_range_opt(len(self), budget=budget) def _(i): self[i] = input_from(player) From 6664de3f77bd7abdfdb9cb8e9b01f4e082d59a91 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 1 Mar 2022 17:08:25 +1100 Subject: [PATCH 032/221] Multiplication of matrices larger than the maximum register size. --- Compiler/types.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index a81a2562d..c090eef97 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5917,7 +5917,22 @@ class t(self.value_type): res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: - res_matrix.assign_vector(self.direct_mul(other)) + if res_matrix.total_size() < _register.maximum_size: + res_matrix.assign_vector(self.direct_mul(other)) + else: + slice = _register.maximum_size // res_matrix.sizes[1] + assert slice > 0 + n = res_matrix.sizes[0] // slice + @library.for_range_opt(n) + def _(i): + res_matrix.assign_part_vector( + self.get_part(i * slice, + slice).direct_mul(other), + i * slice) + base = n * slice + rem = self.sizes[0] - base + res_matrix.assign_part_vector( + self.get_part(base, rem).direct_mul(other), base) except AttributeError: if max(res_matrix.sizes) > 1000: raise AttributeError() From 60dd78797e0e3d93f1e2c02388603f56366eaf46 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 2 Mar 2022 19:25:07 +1100 Subject: [PATCH 033/221] Multithreaded matrix multiplication. --- Compiler/types.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/Compiler/types.py b/Compiler/types.py index c090eef97..7f87b905e 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5885,7 +5885,7 @@ def mul(self, other, res_params=None): # legacy function return self.dot(other, res_params) - def dot(self, other, res_params=None): + def dot(self, other, res_params=None, n_threads=None): """ Matrix-matrix and matrix-vector multiplication. :param self: two-dimensional @@ -5917,22 +5917,12 @@ class t(self.value_type): res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: - if res_matrix.total_size() < _register.maximum_size: - res_matrix.assign_vector(self.direct_mul(other)) - else: - slice = _register.maximum_size // res_matrix.sizes[1] - assert slice > 0 - n = res_matrix.sizes[0] // slice - @library.for_range_opt(n) - def _(i): - res_matrix.assign_part_vector( - self.get_part(i * slice, - slice).direct_mul(other), - i * slice) - base = n * slice - rem = self.sizes[0] - base + self.value_type.direct_matrix_mul + max_size = _register.maximum_size // res_matrix.sizes[1] + @library.multithread(n_threads, self.sizes[0], max_size) + def _(base, size): res_matrix.assign_part_vector( - self.get_part(base, rem).direct_mul(other), base) + self.get_part(base, size).direct_mul(other), base) except AttributeError: if max(res_matrix.sizes) > 1000: raise AttributeError() From 9c0c94dfe2792585e76ba680b9faef146cd997c8 Mon Sep 17 00:00:00 2001 From: shareong <740310627@qq.com> Date: Fri, 4 Mar 2022 17:48:43 +0800 Subject: [PATCH 034/221] update options --- Networking/Player.cpp | 29 +++++++---------------------- Networking/Player.h | 7 +++---- Processor/OnlineMachine.hpp | 12 +++++++++++- Tools/NetworkOptions.cpp | 16 +++++++++++++--- Tools/NetworkOptions.h | 2 +- 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index f3064616a..03709fec7 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -43,43 +43,28 @@ void Names::init(int player,int pnb,vector Nms) } // initialize names from file, no Server.x coordination. -void Names::init(int player, int pnb, const string& filename, int nplayers_wanted) +void Names::init(int player, int pnb, const string& gateway, int nplayers_wanted) { - ifstream hostsfile(filename.c_str()); - if (hostsfile.fail()) - { - stringstream ss; - ss << "Error opening " << filename << ". See HOSTS.example for an example."; - throw file_error(ss.str().c_str()); - } player_no = player; - nplayers = 0; + nplayers = nplayers_wanted; portnum_base = pnb; - string line; ports.clear(); - while (getline(hostsfile, line)) + for (int i = 0; i < nplayers; i++) { - if (line.length() > 0 && line.at(0) != '#') { - auto pos = line.find(':'); + auto pos = gateway.find(':'); if (pos == string::npos) { - names.push_back(line); + names.push_back(gateway); ports.push_back(default_port(nplayers)); } else { - names.push_back(line.substr(0, pos)); + names.push_back(gateway.substr(0, pos)); int port; - stringstream(line.substr(pos + 1)) >> port; + stringstream(gateway.substr(pos + 1)) >> port; ports.push_back(port); } - nplayers++; - if (nplayers_wanted > 0 and nplayers_wanted == nplayers) - break; - } } - if (nplayers_wanted > 0 and nplayers_wanted != nplayers) - throw runtime_error("not enought hosts in HOSTS"); #ifdef DEBUG_NETWORKING cerr << "Got list of " << nplayers << " players from file: " << endl; for (unsigned int i = 0; i < names.size(); i++) diff --git a/Networking/Player.h b/Networking/Player.h index e686397ae..9796dca75 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -93,12 +93,11 @@ class Names * Initialize from file. One party per line, format ``[:]`` * @param player my number * @param pnb base port number - * @param hostsfile filename * @param players number of players (0 to take from file) */ - void init(int player, int pnb, const string& hostsfile, int players = 0); - Names(int player, int pnb, const string& hostsfile) : Names() - { init(player, pnb, hostsfile); } + void init(int player, int pnb, const string& gateway, int players = 0); + Names(int player, int pnb, const string& gateway) : Names() + { init(player, pnb, gateway); } /** * Initialize from command-line options diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 58e91724f..95b3515d0 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -80,6 +80,16 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op "--ip-file-name" // Flag token. ); + opt.add( + "", // Default. + 1, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "gateway to connect", // Help description. + "-gateway", // Flag token. + "--gateway-endpoint" // Flag token. + ); + if (nplayers == 0) opt.add( "2", // Default. @@ -181,7 +191,7 @@ void OnlineMachine::start_networking() opt.get("--portnumbase")->getInt(pnbase); opt.get("--hostname")->getString(hostname); - opt.get("--ip-file-name")->getString(ipFileName); + opt.get("--gateway-endpoint")->getString(ipFileName); ez::OptionGroup* mp_opt = opt.get("--my-port"); if (mp_opt->isSet) diff --git a/Tools/NetworkOptions.cpp b/Tools/NetworkOptions.cpp index aa939cf44..9cb8081ad 100644 --- a/Tools/NetworkOptions.cpp +++ b/Tools/NetworkOptions.cpp @@ -60,6 +60,16 @@ NetworkOptionsWithNumber::NetworkOptionsWithNumber(ez::ezOptionParser& opt, "--ip-file-name" // Flag token. ); + opt.add( + "", // Default. + 1, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "gateway to connect", // Help description. + "-gateway", // Flag token. + "--gateway-endpoint" // Flag token. + ); + opt.parse(argc, argv); if (variable_nplayers) @@ -67,16 +77,16 @@ NetworkOptionsWithNumber::NetworkOptionsWithNumber(ez::ezOptionParser& opt, else nplayers = default_nplayers; - opt.get("-ip")->getString(ip_filename); + opt.get("-gateway")->getString(gateway); opt.resetArgs(); } Server* NetworkOptionsWithNumber::start_networking(Names& N, int my_num) { - if (ip_filename.length() > 0) + if (gateway.length() > 0) { - N.init(my_num, portnum_base, ip_filename, nplayers); + N.init(my_num, portnum_base, gateway, nplayers); return 0; } else diff --git a/Tools/NetworkOptions.h b/Tools/NetworkOptions.h index 8a74271da..448ab0f90 100644 --- a/Tools/NetworkOptions.h +++ b/Tools/NetworkOptions.h @@ -25,7 +25,7 @@ class NetworkOptionsWithNumber : NetworkOptions { public: int nplayers; - std::string ip_filename; + std::string gateway; NetworkOptionsWithNumber(ez::ezOptionParser& opt, int argc, const char** argv, int default_nplayers, bool variable_nplayers); From e485aacd37b7e2f8901b1bee5a4ec132a536a043 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 2 Mar 2022 12:32:13 +1100 Subject: [PATCH 035/221] Easier change of domain in SPDZ2k. --- Machines/SPDZ2k.cpp | 10 ++++++++++ Machines/spdz2k-party.cpp | 22 +++++++++++++++++++--- Processor/RingMachine.hpp | 13 +++++++++---- Protocols/Spdz2kShare.h | 4 ++++ 4 files changed, 42 insertions(+), 7 deletions(-) create mode 100644 Machines/SPDZ2k.cpp diff --git a/Machines/SPDZ2k.cpp b/Machines/SPDZ2k.cpp new file mode 100644 index 000000000..62a4324e2 --- /dev/null +++ b/Machines/SPDZ2k.cpp @@ -0,0 +1,10 @@ +/* + * SPDZ2k.cpp + * + */ + +#include "SPDZ2k.hpp" + +#ifdef RING_SIZE +template class Machine, Share>; +#endif diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index 188b62912..8aba31737 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -10,7 +10,7 @@ #include "Math/gf2n.h" #include "Networking/Server.h" -#include "Processor/OnlineMachine.hpp" +#include "Processor/RingMachine.hpp" #include "Math/Z2k.hpp" int main(int argc, const char** argv) @@ -46,7 +46,23 @@ int main(int argc, const char** argv) Z(72, 64) Z(72, 48) +#ifdef RING_SIZE + Z(RING_SIZE, SPDZ2K_DEFAULT_SECURITY) +#endif + else - throw runtime_error( - "not compiled for k=" + to_string(k) + " and s=" + to_string(s)); + { + if (s == SPDZ2K_DEFAULT_SECURITY) + { + ring_domain_error(k); + } + else + { + cerr << "not compiled for k=" << k << " and s=" << s << "," << endl; + cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line " + << (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+" + << s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl; + } + exit(1); + } } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index e422e0aa5..f2bfc6c1b 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -30,6 +30,13 @@ HonestMajorityRingMachine::HonestMajorityRingMachine(int argc, const char* RingMachine(argc, argv, opt, online_opts, nplayers); } +inline void ring_domain_error(int R) +{ + cerr << "not compiled for " << R << "-bit computation, " << endl; + cerr << "compile with -DRING_SIZE=" << R << endl; + exit(1); +} + template class U, template class V, class W> RingMachine::RingMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) @@ -49,8 +56,7 @@ RingMachine::RingMachine(int argc, const char** argv, #endif #undef X default: - cerr << "not compiled for " << to_string(R) + "-bit computation" << endl; - exit(1); + ring_domain_error(R); } } @@ -88,8 +94,7 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri #endif #undef X default: - cerr << "not compiled for " << to_string(R) + "-bit computation" << endl; - exit(1); + ring_domain_error(R); } } diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index d26cde4f2..401070f84 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -6,6 +6,10 @@ #ifndef PROTOCOLS_SPDZ2KSHARE_H_ #define PROTOCOLS_SPDZ2KSHARE_H_ +#ifndef SPDZ2K_DEFAULT_SECURITY +#define SPDZ2K_DEFAULT_SECURITY 64 +#endif + #include "Math/Z2k.h" #include "Protocols/Share.h" #include "Protocols/MAC_Check.h" From 0501a2701cc11376e063817c8731871e47eae835 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 8 Mar 2022 17:05:44 +1100 Subject: [PATCH 036/221] Document domain types. --- Math/Z2k.h | 47 ++++++++++++++++++++++++++++++++++- Math/gfp.h | 63 ++++++++++++++++++++++++++++++++++++++++++++--- Math/gfpvar.h | 8 ++++++ doc/Doxyfile | 2 +- doc/low-level.rst | 16 ++++++++++++ 5 files changed, 131 insertions(+), 5 deletions(-) diff --git a/Math/Z2k.h b/Math/Z2k.h index ad32cbf16..586c78c06 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -19,6 +19,12 @@ using namespace std; template class IntBase; template class fixint; +/** + * Type for values in the ring defined by the integers modulo ``2^K`` + * representing `[0, 2^K-1]`. + * It supports arithmetic, bit-wise, and output streaming operations. + * It does not need initialization because ``K`` completely defines the domain. + */ template class Z2 : public ValueInterface { @@ -71,6 +77,9 @@ class Z2 : public ValueInterface typedef Z2 next; typedef Z2 Scalar; + /** + * Initialize to zero. + */ Z2() { assign_zero(); } Z2(mp_limb_t x) : Z2() { a[0] = x; } Z2(__m128i x) : Z2() { avx_memcpy(a, &x, min(N_BYTES, 16)); } @@ -78,8 +87,14 @@ class Z2 : public ValueInterface Z2(long x) : Z2(mp_limb_t(x)) { if (K > 64 and x < 0) memset(&a[1], -1, N_BYTES - 8); } template Z2(const IntBase& x); + /** + * Convert from unrestricted integer. + */ Z2(const bigint& x); Z2(const void* buffer) : Z2() { assign(buffer); } + /** + * Convert from different domain via the canonical integer representation. + */ template Z2(const Z2& x) : Z2() { avx_memcpy(a, x.a, min(N_BYTES, x.N_BYTES)); normalize(); } @@ -140,19 +155,38 @@ class Z2 : public ValueInterface Z2 invert() const; + /** + * Deterministic square root for values with least significate bit 1. + * Raises an exception otherwise. + */ Z2 sqrRoot(); bool is_zero() const { return *this == Z2(); } bool is_one() const { return *this == 1; } bool is_bit() const { return is_zero() or is_one(); } + /** + * Sample with uniform distribution. + * @param G randomness generator + * @param n (unused) + */ void randomize(PRNG& G, int n = -1); void randomize_part(PRNG& G, int n); void almost_randomize(PRNG& G) { randomize(G); } void force_to_bit() { throw runtime_error("impossible"); } + /** + * Append to buffer in native format. + * @param o buffer + * @param n (unused) + */ void pack(octetStream& o, int = -1) const; + /** + * Read from buffer in native format + * @param o buffer + * @param n (unused) + */ void unpack(octetStream& o, int n = -1); void input(istream& s, bool human=true); @@ -162,21 +196,32 @@ class Z2 : public ValueInterface friend ostream& operator<<(ostream& o, const Z2& x); }; +/** + * Type for values in the ring defined by the integers modulo ``2^K`` + * representing `[-2^(K-1), 2^(K-1)-1]`. + * It supports arithmetic, bit-wise, comparison, and output streaming operations. + * It does not need initialization because ``K`` completely defines the domain. + */ template class SignedZ2 : public Z2 { public: + /** + * Initialization to zero + */ SignedZ2() { } + /** + * Conversion from another domain via the signed representation + */ template SignedZ2(const SignedZ2& other) : Z2(other) { extend(other); } - template void extend(const SignedZ2& other) { diff --git a/Math/gfp.h b/Math/gfp.h index bde43025e..3bc23e194 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -40,6 +40,16 @@ template void generate_prime_setup(string, int, int); #error GFP_MOD_SZ must be at most MAX_MOD_SZ #endif +/** + * Type for values in a field defined by integers modulo a prime + * in a specific range for fixed storage. + * It supports basic arithmetic operations and bit-wise operations. + * The latter use the canonical representation in the range `[0, p-1]`. + * ``X`` is a counter to allow several moduli being used at the same time. + * ``L`` is the number of 64-bit limbs, that is, + * the prime has to have bit length in `[64*L-63, 64*L]`. + * See ``gfpvar_`` for a more flexible alternative. + */ template class gfp_ : public ValueInterface { @@ -72,7 +82,17 @@ class gfp_ : public ValueInterface template static void init(bool mont = true) { init_field(T::pr(), mont); } + /** + * Initialize the field. + * @param p: prime modulus + * @param mont: whether to use Montgomery representation + */ static void init_field(const bigint& p,bool mont=true); + /** + * Initialize the field to a prime of a given bit length. + * @param lgp: bit length + * @param mont: whether to use Montgomery representation + */ static void init_default(int lgp, bool mont = true); static void read_or_generate_setup(string dir, const OnlineOptions& opts); template @@ -85,6 +105,9 @@ class gfp_ : public ValueInterface { write_online_setup(dir, pr()); } static void check_setup(string dir); + /** + * Get the prime modulus + */ static const bigint& pr() { return ZpD.pr; } static int t() @@ -126,15 +149,24 @@ class gfp_ : public ValueInterface const void* get_ptr() const { return &a.x; } void* get_ptr() { return &a.x; } + /** + * Initialize to zero. + */ gfp_() { assignZero(a,ZpD); } template gfp_(const modp_& g) { a=g; } + /** + * Convert from integer without range restrictions. + */ gfp_(const mpz_class& x) { to_modp(a, x, ZpD); } gfp_(int x) : gfp_(long(x)) {} gfp_(long x); gfp_(word x) : gfp_(bigint::tmp = x) {} template gfp_(IntBase x) : gfp_(x.get()) {} + /** + * Convert from different domain via canonical integer representation. + */ template gfp_(const gfp_& x); gfp_(const gfpvar& other); @@ -181,9 +213,16 @@ class gfp_ : public ValueInterface void negate() { Negate(a,a,ZpD); } - // deterministic square root + /** + * Deterministic square root. + */ gfp_ sqrRoot(); + /** + * Sample with uniform distribution. + * @param G randomness generator + * @param n (unused) + */ void randomize(PRNG& G, int n = -1) { (void) n; a.randomize(G,ZpD); } // faster randomization, see implementation for explanation @@ -194,10 +233,20 @@ class gfp_ : public ValueInterface void input(istream& s,bool human) { a.input(s,ZpD,human); } + /** + * Human-readable output in the range `[-p/2, p/2]`. + * @param s output stream + * @param x value + */ friend ostream& operator<<(ostream& s,const gfp_& x) { x.output(s,true); return s; } + /** + * Human-readable input without range restrictions + * @param s input stream + * @param x value + */ friend istream& operator>>(istream& s,gfp_& x) { x.input(s,true); return s; @@ -220,10 +269,18 @@ class gfp_ : public ValueInterface void force_to_bit() { throw runtime_error("impossible"); } - // Pack and unpack in native format - // i.e. Dont care about conversion to human readable form + /** + * Append to buffer in native format. + * @param o buffer + * @param n (unused) + */ void pack(octetStream& o, int n = -1) const { (void) n; a.pack(o); } + /** + * Read from buffer in native format + * @param o buffer + * @param n (unused) + */ void unpack(octetStream& o, int n = -1) { (void) n; a.unpack(o); } diff --git a/Math/gfpvar.h b/Math/gfpvar.h index 438a935e5..a3b475f8c 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -14,6 +14,14 @@ class FFT_Data; template class BitVec_; +/** + * Type for values in a field defined by integers modulo a prime + * up to a certain length for fixed storage. + * ``X`` is a counter to allow several moduli being used at the same time. + * ``L`` is the maximum number of 64-bit limbs, that is, + * the prime has to have bit length at most `64*L`. + * The interface replicates ``gfp_``. + */ template class gfpvar_ { diff --git a/doc/Doxyfile b/doc/Doxyfile index 3dd299405..36837c385 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/low-level.rst b/doc/low-level.rst index 7f5474fd4..fd9d2bfc4 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -355,3 +355,19 @@ Protocol Interfaces .. doxygenclass:: BufferPrep :members: + + +Domain Reference +---------------- + +.. doxygenclass:: gfp_ + :members: + +.. doxygenclass:: gfpvar_ + :members: + +.. doxygenclass:: Z2 + :members: + +.. doxygenclass:: SignedZ2 + :members: From 6a223a6b99340e269c8f627fb888891d6aa6990a Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 8 Mar 2022 10:01:19 +0100 Subject: [PATCH 037/221] RTD build. --- doc/Doxyfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/Doxyfile b/doc/Doxyfile index 36837c385..9820ba50c 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -801,7 +801,7 @@ WARN_NO_PARAMDOC = NO # a warning is encountered. # The default value is: NO. -WARN_AS_ERROR = YES +WARN_AS_ERROR = NO # The WARN_FORMAT tag determines the format of the warning messages that doxygen # can produce. The string should contain the $file, $line, and $text tags, which From c040a54f634f141b59046f9742a73af095a30441 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Wed, 9 Mar 2022 20:23:08 -0600 Subject: [PATCH 038/221] Fix typo in docs --- doc/networking.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/networking.rst b/doc/networking.rst index a1c61b98d..c7e031f10 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -13,7 +13,7 @@ base port number, which can be changed using the same option. There are two ways of communicating hosts and individually setting ports: -1. All parties first to connect to a coordination server, which +1. All parties first connect to a coordination server, which broadcasts the data for all parties. This is the default with the coordination server being run as a thread of party 0. The hostname of the coordination server has to be given with the command-line From b283fdb385c0777c5faf7a63557a415955a50995 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 9 Mar 2022 18:25:59 +1100 Subject: [PATCH 039/221] Improved multi-threaded tree reduction. --- Compiler/library.py | 10 ++++++++++ Compiler/types.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/Compiler/library.py b/Compiler/library.py index 3f31499b0..46d72ec7b 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1340,6 +1340,16 @@ def summer(*args): return map_reduce(n_threads, 1, n_loops, initializer, summer) def tree_reduce_multithread(n_threads, function, vector): + """ Round-efficient reduction in several threads. The following code + computes the maximum of an array in 10 threads:: + + tree_reduce_multithread(10, lambda x, y: x.max(y), a) + + :param n_threads: number of threads (int) + :param function: reduction function taking exactly two arguments + :param vector: register vector or array + + """ inputs = vector.Array(len(vector)) inputs.assign_vector(vector) outputs = vector.Array(len(vector) // 2) diff --git a/Compiler/types.py b/Compiler/types.py index 7f87b905e..b63597338 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5554,6 +5554,10 @@ def sort(self, n_threads=None): """ library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + def Array(self, size): + # compatibility with registers + return Array(size, self.value_type) + def __str__(self): return '%s array of length %s at %s' % (self.value_type, len(self), self.address) From 1227376ae3620180be956e3846b08c73343f50b1 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 14 Mar 2022 23:41:18 +1100 Subject: [PATCH 040/221] Bug in conversion from secret integer to secret bits. --- Compiler/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 2f7ca81f5..23bee2190 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -292,7 +292,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): """ program.curr_tape.require_bit_length(k + kappa) from .types import sint - if program.use_edabit() and m > 1 and not const_rounds: + if program.use_edabit() and not const_rounds: movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0]) tmp, b[:] = sint.get_edabit(m, True) movs(r_prime, tmp) From 5b248ced927b89f2d357c02fa27faee7712e45ad Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 16 Mar 2022 12:12:14 +1100 Subject: [PATCH 041/221] Bug in negative sbits input. --- GC/ShareSecret.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 23c86cb28..267b6f839 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -162,7 +162,7 @@ void Processor::inputb(typename T::Input& input, ProcessorBase& input_process for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) { auto& res = S[x.dest + i]; - res.my_input(input, bigint(whole_input >> (i * dl)).get_ui(), + res.my_input(input, bigint(whole_input >> (i * dl)).get_si(), min(dl, x.n_bits - i * dl)); } } From 28341241e0721359aa040102d9cfb52472b26e2d Mon Sep 17 00:00:00 2001 From: HaoXuan40404 <444649358@qq.com> Date: Thu, 31 Mar 2022 15:19:51 +0800 Subject: [PATCH 042/221] fix fuzzy search in database --- Compiler/ppc.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/Compiler/ppc.py b/Compiler/ppc.py index 4950b143d..88dc8f2d9 100644 --- a/Compiler/ppc.py +++ b/Compiler/ppc.py @@ -333,5 +333,15 @@ def pint_floordiv(self, other): return to_pint(pint_div(self, other)) +def write_matrix_to_file(matrix, matrix_row, matrix_col): + array_result = Array(matrix_row * matrix_col, matrix.value_type) + @for_range(matrix_row) + def _(i): + @for_range(matrix_col) + def _(j): + array_result[i*matrix_col+j] = matrix[i][j] + array_result.write_to_file() + + pint.__mod__ = pint_mod pint.__floordiv__ = pint_floordiv From ddeef9360427d3df1498056aea6e90c6143b259c Mon Sep 17 00:00:00 2001 From: HaoXuan40404 <444649358@qq.com> Date: Thu, 31 Mar 2022 19:39:33 +0800 Subject: [PATCH 043/221] add write matrix to file --- Compiler/ppc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Compiler/ppc.py b/Compiler/ppc.py index 88dc8f2d9..de1ab8a3d 100644 --- a/Compiler/ppc.py +++ b/Compiler/ppc.py @@ -334,12 +334,14 @@ def pint_floordiv(self, other): def write_matrix_to_file(matrix, matrix_row, matrix_col): - array_result = Array(matrix_row * matrix_col, matrix.value_type) + array_result = Array(matrix_row * matrix_col + 2, matrix.value_type) + array_result[0] = matrix.value_type(matrix_row) + array_result[1] = matrix.value_type(matrix_col) @for_range(matrix_row) def _(i): @for_range(matrix_col) def _(j): - array_result[i*matrix_col+j] = matrix[i][j] + array_result[2 + i*matrix_col+j] = matrix[i][j] array_result.write_to_file() From 07292ec09dbfcfc996802e4e2f88e943b253ccb1 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 31 Mar 2022 18:30:12 +0200 Subject: [PATCH 044/221] Bug in integer division. --- Compiler/library.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index 46d72ec7b..06503c7b8 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1848,7 +1848,8 @@ def block(i): return (sign_a * sign_b) * A def IntDiv(a, b, k, kappa=None): - return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k, + l = 2 * k + 1 + return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k, kappa, nearest=True) @instructions_base.ret_cisc From 565c364cd4204a8d697c7ab3d235774a15ecb29e Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 2 Apr 2022 07:41:22 +0200 Subject: [PATCH 045/221] Bug in fixed-point division. --- Compiler/library.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Compiler/library.py b/Compiler/library.py index 06503c7b8..35d6f46f2 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1876,19 +1876,20 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() - y = a.extend(2 *k) * w - y = y.round(2*k, f, kappa, nearest, signed=True) + l_y = k + 3 * f - res_f + y = a.extend(l_y) * w + y = y.round(l_y, f, kappa, nearest, signed=True) for i in range(theta - 1): x = x.extend(2 * k) - y = y.extend(2 * k) * (alpha + x).extend(2 * k) + y = y.extend(l_y) * (alpha + x).extend(l_y) x = x * x - y = y.round(2*k, 2*f, kappa, nearest, signed=True) + y = y.round(l_y, 2*f, kappa, nearest, signed=True) x = x.round(2*k, 2*f, kappa, nearest, signed=True) x = x.extend(2 * k) - y = y.extend(2 * k) * (alpha + x).extend(2 * k) - y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) + y = y.extend(l_y) * (alpha + x).extend(l_y) + y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True) return y def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False): From b0e7857cbc29dfa12d5339abf114e83f10b9baf9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 1 Apr 2022 20:28:41 +0200 Subject: [PATCH 046/221] Store MAC keys for persistence. --- Processor/Machine.hpp | 9 +++++---- Protocols/fake-stuff.h | 2 +- Protocols/fake-stuff.hpp | 4 +++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index cd318f1aa..e720b2a99 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -6,6 +6,7 @@ #include "Memory.hpp" #include "Online-Thread.hpp" #include "Protocols/Hemi.hpp" +#include "Protocols/fake-stuff.hpp" #include "Tools/Exceptions.h" @@ -60,10 +61,10 @@ Machine::Machine(int my_number, Names& playerNames, sint::LivePrep::basic_setup(*P); } - sint::read_or_generate_mac_key(prep_dir_prefix(), *P, alphapi); - sgf2n::read_or_generate_mac_key(prep_dir_prefix(), *P, alpha2i); - sint::bit_type::part_type::read_or_generate_mac_key( - prep_dir_prefix(), *P, alphabi); + alphapi = read_generate_write_mac_key(*P); + alpha2i = read_generate_write_mac_key(*P); + alphabi = read_generate_write_mac_key(*P); #ifdef DEBUG_MAC cerr << "MAC Key p = " << alphapi << endl; diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 0c209869b..d15581ebd 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -34,7 +34,7 @@ template void read_mac_key(const string& directory, const Names& N, U& key); template -typename T::mac_key_type read_generate_write_mac_key(const Player& P, +typename T::mac_key_type read_generate_write_mac_key(Player& P, string directory = ""); template diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 45d92613f..aeb516118 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -272,7 +272,9 @@ void write_mac_key(const string& directory, int i, int nplayers, U key) ofstream outf; stringstream filename; filename << mac_filename(directory, i); +#ifdef VERBOSE cout << "Writing to " << filename.str().c_str() << endl; +#endif outf.open(filename.str().c_str()); outf << nplayers << endl; key.output(outf,true); @@ -333,7 +335,7 @@ void read_mac_key(const string& directory, int player_num, int nplayers, U& key) } template -inline typename T::mac_key_type read_generate_write_mac_key(const Player& P, +typename T::mac_key_type read_generate_write_mac_key(Player& P, string directory) { if (directory == "") From f42930edc306a5145ba3d80fe10b7a2e0e452925 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 2 Apr 2022 08:06:03 +0200 Subject: [PATCH 047/221] Sufficient offline preprocessing with several threads. --- Processor/OfflineMachine.h | 3 +++ Processor/OfflineMachine.hpp | 15 +++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Processor/OfflineMachine.h b/Processor/OfflineMachine.h index 792c5bdaa..d1b142569 100644 --- a/Processor/OfflineMachine.h +++ b/Processor/OfflineMachine.h @@ -18,10 +18,13 @@ class OfflineMachine : public W BaseMachine machine; Names& playerNames; Player& P; + int n_threads; template void generate(); + int buffered_total(size_t required, size_t batch); + public: template OfflineMachine(int argc, const char** argv, diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index dcfafe553..6e0bb525d 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -22,6 +22,7 @@ OfflineMachine::OfflineMachine(int argc, const char** argv, Program program(playerNames.num_players()); program.parse(machine.bc_filenames[0]); usage = program.get_offline_data_used(); + n_threads = machine.nthreads; machine.ot_setups.push_back({P}); } @@ -52,6 +53,12 @@ int OfflineMachine::run() return 0; } +template +int OfflineMachine::buffered_total(size_t required, size_t batch) +{ + return DIV_CEIL(required, batch) * batch + (n_threads - 1) * batch; +} + template template void OfflineMachine::generate() @@ -79,7 +86,7 @@ void OfflineMachine::generate() if (i == DATA_DABIT) { for (long long j = 0; - j < DIV_CEIL(my_usage, BUFFER_SIZE) * BUFFER_SIZE; j++) + j < buffered_total(my_usage, BUFFER_SIZE); j++) { T a; typename T::bit_type b; @@ -91,7 +98,7 @@ void OfflineMachine::generate() { vector tuple(DataPositions::tuple_size[i]); for (long long j = 0; - j < DIV_CEIL(my_usage, BUFFER_SIZE) * BUFFER_SIZE; j++) + j < buffered_total(my_usage, BUFFER_SIZE); j++) { preprocessing.get(dtype, tuple.data()); for (auto& x : tuple) @@ -113,7 +120,7 @@ void OfflineMachine::generate() file_signature().output(out); InputTuple tuple; for (long long j = 0; - j < DIV_CEIL(n_inputs, BUFFER_SIZE) * BUFFER_SIZE; j++) + j < buffered_total(n_inputs, BUFFER_SIZE); j++) { preprocessing.get_input(tuple.share, tuple.value, i); tuple.share.output(out, false); @@ -142,7 +149,7 @@ void OfflineMachine::generate() { ofstream out(filename, ios::binary); file_signature().output(out); - for (int i = 0; i < DIV_CEIL(total, batch) * batch; i++) + for (int i = 0; i < buffered_total(total, batch); i++) preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits, out); } From 06f3f21cee7ba02af9ed68144cd49be0cb2340bf Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Thu, 10 Feb 2022 23:43:12 -0600 Subject: [PATCH 048/221] Add SSL_DIR env var --- CONFIG | 3 +++ Networking/CryptoPlayer.cpp | 6 +++--- Networking/ssl_sockets.h | 8 ++++++-- README.md | 3 ++- Scripts/setup-ssl.sh | 7 ++++--- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/CONFIG b/CONFIG index 05b3683d5..bf92327ed 100644 --- a/CONFIG +++ b/CONFIG @@ -8,6 +8,9 @@ GDEBUG = -g # set this to your preferred local storage directory PREP_DIR = '-DPREP_DIR="Player-Data/"' +# directory to store SSL keys +SSL_DIR = '-DSSL_DIR="Player-Data/"' + # set for SHE preprocessing (SPDZ and Overdrive) USE_NTL = 0 diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 9d8da6514..43b2ada5c 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -19,9 +19,9 @@ void ssl_error(string side, string other, string me) { cerr << side << "-side handshake with " << other << " failed. Make sure both sides " - << " have the necessary certificate (" << PREP_DIR << me + << " have the necessary certificate (" << SSL_DIR << me << ".pem in the default configuration on their side and " - << PREP_DIR << other << ".pem on ours)," + << SSL_DIR << other << ".pem on ours)," << " and run `c_rehash ` on its location." << endl << "The certificates should be the same on every host. " << "Also make sure that it's still valid. Certificates generated " @@ -36,7 +36,7 @@ void ssl_error(string side, string other, string me) cerr << "Signature (should match the other side): "; for (int i = 0; i < 2; i++) { - auto filename = PREP_DIR + ids[i] + ".pem"; + auto filename = SSL_DIR + ids[i] + ".pem"; ifstream cert(filename); stringstream buffer; buffer << cert.rdbuf(); diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index 79cb35222..fe9477a81 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -14,6 +14,10 @@ #include #include +#ifndef SSL_DIR +#define SSL_DIR "Player-Data/" +#endif + typedef boost::asio::io_service ssl_service; void check_ssl_file(string filename); @@ -25,7 +29,7 @@ class ssl_ctx : public boost::asio::ssl::context ssl_ctx(string me) : boost::asio::ssl::context(boost::asio::ssl::context::tlsv12) { - string prefix = PREP_DIR + me; + string prefix = SSL_DIR + me; string cert_file = prefix + ".pem"; string key_file = prefix + ".key"; check_ssl_file(cert_file); @@ -33,7 +37,7 @@ class ssl_ctx : public boost::asio::ssl::context use_certificate_file(cert_file, pem); use_private_key_file(key_file, pem); - add_verify_path(PREP_DIR); + add_verify_path(SSL_DIR); } }; diff --git a/README.md b/README.md index 99d0f0763..4f9b44567 100644 --- a/README.md +++ b/README.md @@ -283,6 +283,7 @@ compute the preprocessing time for a particular computation. on available options. - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to a local, unversioned directory to store preprocessing data (the default is `Player-Data` in the current directory). + - `SSL_DIR` should point to a local, unversioned directory to store ssl keys (the default is `Player-Data` in the current directory). - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. 2) Run `make` to compile all the software (use the flag `-j` for faster @@ -707,7 +708,7 @@ information. MP-SPDZ uses OpenSSL for secure channels. You can generate the necessary certificates and keys as follows: -`Scripts/setup-ssl.sh []` +`Scripts/setup-ssl.sh [ ]` The programs expect the keys and certificates to be in `Player-Data/P.key` and `Player-Data/P.pem`, respectively, and diff --git a/Scripts/setup-ssl.sh b/Scripts/setup-ssl.sh index ffd79bf0d..01113f166 100755 --- a/Scripts/setup-ssl.sh +++ b/Scripts/setup-ssl.sh @@ -4,13 +4,14 @@ PATH=/usr/local/opt/openssl/bin:$PATH n=${1:-4} +ssl_dir=${2:-"Player-Data"} -test -e Player-Data || mkdir Player-Data +test -e $ssl_dir || mkdir $ssl_dir echo Setting up SSL for $n parties for i in `seq 0 $[n-1]`; do - openssl req -newkey rsa -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" + openssl req -newkey rsa -nodes -x509 -out $ssl_dir/P$i.pem -keyout $ssl_dir/P$i.key -subj "/CN=P$i" done -c_rehash Player-Data +c_rehash $ssl_dir From 68d6eb0832af9076e317b046c9f3deb69e06724b Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Fri, 8 Apr 2022 20:54:12 -0500 Subject: [PATCH 049/221] Add .env to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ce68ca4e4..9a4dd72e2 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,6 @@ Thumbs.db # Sphinx build _build/ + +# environment +.env From 9b4e0447eb5a1a55233970462f4cdc92d9d7ba84 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Tue, 12 Apr 2022 20:54:35 -0500 Subject: [PATCH 050/221] Add missing SSL_DIR cflag in CONFIG --- CONFIG | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONFIG b/CONFIG index bf92327ed..cef15e0b4 100644 --- a/CONFIG +++ b/CONFIG @@ -87,7 +87,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) From 5773930bbfe085e69a171d8a2cbd128b9ffa0f8a Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Tue, 12 Apr 2022 21:44:08 -0500 Subject: [PATCH 051/221] Emphasize that ssl keys must be under SSL_DIR --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4f9b44567..59a90a112 100644 --- a/README.md +++ b/README.md @@ -708,16 +708,18 @@ information. MP-SPDZ uses OpenSSL for secure channels. You can generate the necessary certificates and keys as follows: -`Scripts/setup-ssl.sh [ ]` +`Scripts/setup-ssl.sh [ ]` The programs expect the keys and certificates to be in -`Player-Data/P.key` and `Player-Data/P.pem`, respectively, and +`SSL_DIR/P.key` and `SSL_DIR/P.pem`, respectively, and the certificates to have the common name `P` for player ``. Furthermore, the relevant root certificates have to be in -`Player-Data` such that OpenSSL can find them (run `c_rehash -Player-Data`). The script above takes care of all this by generating +`SSL_DIR` such that OpenSSL can find them (run `c_rehash +`). The script above takes care of all this by generating self-signed certificates. Therefore, if you are running the programs on different hosts you will need to copy the certificate files. +Note that `` must match `SSL_DIR` set in `CONFIG` or `CONFIG.mine`. +Just like `SSL_DIR`, `` defaults to `Player-Data`. In the following, we will walk through running the tutorial modulo 2^k with three parties. The other programs work similarly. From 30508236a772cacda146e4f518125c438617202c Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Fri, 8 Apr 2022 15:16:34 -0500 Subject: [PATCH 052/221] Add Dockerfile --- Dockerfile | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..e01dd5c53 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,149 @@ +############################################################################### +# Build this stage for a build environment, e.g.: # +# # +# docker build --tag mpspdz:buildenv --target buildenv . # +# # +# The above is equivalent to: # +# # +# docker build --tag mpspdz:buildenv \ # +# --target buildenv \ # +# --build-arg arch=native \ # +# --build-arg cxx=clang++-11 \ # +# --build-arg use_ntl=0 \ # +# --build-arg prep_dir="Player-Data" \ # +# --build-arg ssl_dir="Player-Data" # +# --build-arg cryptoplayers=0 # +# # +# To build for an x86-64 architecture, with g++, NTL (for HE), custom # +# prep_dir & ssl_dir, and to use encrypted channels for 4 players: # +# # +# docker build --tag mpspdz:buildenv \ # +# --target buildenv \ # +# --build-arg arch=x86-64 \ # +# --build-arg cxx=g++ \ # +# --build-arg use_ntl=1 \ # +# --build-arg prep_dir="/opt/prepdata" \ # +# --build-arg ssl_dir="/opt/ssl" # +# --build-arg cryptoplayers=4 . # +# # +# To work in a container to build different machines, and compile programs: # +# # +# docker run --rm -it mpspdz:buildenv bash # +# # +# Once in the container, build a machine and compile a program: # +# # +# $ make replicated-ring-party.x # +# $ ./compile.py -R 64 tutorial # +# # +############################################################################### +FROM python:3.10.3-bullseye as buildenv + +RUN apt-get update && apt-get install -y --no-install-recommends \ + automake \ + build-essential \ + clang-11 \ + git \ + libboost-dev \ + libboost-thread-dev \ + libclang-dev \ + libntl-dev \ + libsodium-dev \ + libssl-dev \ + libtool \ + m4 \ + texinfo \ + yasm \ + vim \ + gdb \ + valgrind \ + && rm -rf /var/lib/apt/lists/* + +# mpir +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/ +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/ +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/ + +ENV MP_SPDZ_HOME /usr/src/MP-SPDZ +WORKDIR $MP_SPDZ_HOME + +RUN pip install --upgrade pip ipython + +COPY . . + +ARG arch=native +ARG cxx=clang++-11 +ARG use_ntl=0 +ARG prep_dir="Player-Data" +ARG ssl_dir="Player-Data" + +RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \ + && echo "CXX = ${cxx}" >> CONFIG.mine \ + && echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \ + && echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \ + && echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \ + >> CONFIG.mine \ + && mkdir -p $prep_dir $ssl_dir \ + && echo "PREP_DIR = '-DPREP_DIR=\"${prep_dir}/\"'" >> CONFIG.mine \ + && echo "SSL_DIR = '-DSSL_DIR=\"${ssl_dir}/\"'" >> CONFIG.mine + +# ssl keys +ARG cryptoplayers=0 +ENV PLAYERS ${cryptoplayers} +RUN ./Scripts/setup-ssl.sh ${cryptoplayers} ${ssl_dir} + + +############################################################################### +# Use this stage to a build a specific virtual machine. For example: # +# # +# docker build --tag mpspdz:shamir \ # +# --target machine \ # +# --build-arg machine=shamir-party.x \ # +# --build-arg gfp_mod_sz=4 . # +# # +# The above will build shamir-party.x with 256 bit length. # +# # +# If no build arguments are passed (via --build-arg), mascot-party.x is built # +# with the default 128 bit length. # +############################################################################### +FROM buildenv as machine + +ARG machine="mascot-party.x" + +ARG gfp_mod_sz=2 + +RUN echo "MOD = -DGFP_MOD_SZ=${gfp_mod_sz}" >> CONFIG.mine + +RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/ + + +################################################################################ +# This is the default stage. Use it to compile a high-level program. # +# By default, tutorial.mpc is compiled with --field=64 bits. # +# # +# docker build --tag mpspdz:mascot-tutorial \ # +# --build-arg src=tutorial \ # +# --build-arg compile_options="--field=64" . # +# # +# Note that build arguments from previous stages can also be passed. For # +# instance, building replicated-ring-party.x, for 3 crypto players with custom # +# PREP_DIR and SSL_DIR, and compiling tutorial.mpc with --ring=64: # +# # +# docker build --tag mpspdz:replicated-ring \ # +# --build-arg machine=replicated-ring-party.x \ # +# --build-arg prep_dir=/opt/prep \ # +# --build-arg ssl_dir=/opt/ssl \ # +# --build-arg nparties=3 \ # +# --build-arg compile_options="--ring=64" . # +# # +# Test it: # +# # +# docker run --rm -it mpspdz:replicated-ring ./Scripts/ring.sh tutorial # +################################################################################ +FROM machine as program + +ARG src="tutorial" +ARG compile_options="--field=64" +RUN ./compile.py ${compile_options} ${src} +RUN mkdir -p Player-Data \ + && echo 1 2 3 4 > Player-Data/Input-P0-0 \ + && echo 1 2 3 4 > Player-Data/Input-P1-0 From 4d17b4f38957306e5c105dd4e097a0138a60f889 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Mon, 11 Apr 2022 15:18:07 -0500 Subject: [PATCH 053/221] Add tl;dr for docker in readme --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 4f9b44567..b5b4445ee 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,21 @@ echo 1 2 3 4 > Player-Data/Input-P1-0 Scripts/mascot.sh tutorial ``` +#### TL;DR (Docker) +Build a docker image for `mascot-party.x`: + +``` +docker build --tag mpspdz:mascot-party --build-arg machine=mascot-party.x . +``` + +Run the [the tutorial](Programs/Source/tutorial.mpc): + +``` +docker run --rm -it mpspdz:mascot-party ./Scripts/mascot.sh tutorial +``` + +See the [`Dockerfile`](./Dockerfile) for examples of how it can be used. + #### Preface The primary aim of this software is to run the same computation in From 24bf33aba24d55959fb0fb6fa6550c8c2a127bf1 Mon Sep 17 00:00:00 2001 From: Bishakh Ghosh Date: Fri, 15 Apr 2022 15:41:51 +0530 Subject: [PATCH 054/221] Fix typo in ECDSA readme --- ECDSA/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ECDSA/README.md b/ECDSA/README.md index 6307a91e6..7cb014ce8 100644 --- a/ECDSA/README.md +++ b/ECDSA/README.md @@ -24,8 +24,8 @@ The following binaries have been used for the paper: All binaries offer the same interface. With MASCOT for example, run the following: ``` -./mascot-ecsda-party.x -p 0 [-N ] [-h ] [-D] [] -./mascot-ecsda-party.x -p 1 [-N ] [-h ] [-D] [] +./mascot-ecdsa-party.x -p 0 [-N ] [-h ] [-D] [] +./mascot-ecdsa-party.x -p 1 [-N ] [-h ] [-D] [] ... ``` From 9ef15cc2f56d1aa335d9884e3f7bb75be0eed2af Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 19 Apr 2022 15:10:18 +0200 Subject: [PATCH 055/221] Protocol in dealer model. --- CHANGELOG.md | 10 ++ Compiler/GC/instructions.py | 8 +- Compiler/GC/types.py | 15 +- Compiler/allocator.py | 4 +- Compiler/instructions.py | 12 +- Compiler/instructions_base.py | 13 +- Compiler/library.py | 13 ++ Compiler/ml.py | 7 +- Compiler/program.py | 3 + Compiler/types.py | 35 ++++- ECDSA/P256Element.h | 1 + ECDSA/semi-ecdsa-party.cpp | 1 + ExternalIO/Client.h | 4 +- ExternalIO/Client.hpp | 6 +- ExternalIO/bankers-bonus-client.cpp | 18 +-- FHE/FHE_Params.cpp | 23 ++- FHE/FHE_Params.h | 5 +- FHE/Matrix.cpp | 13 +- FHE/NTL-Subs.cpp | 73 ++++++--- FHE/NTL-Subs.h | 9 +- FHE/NoiseBounds.cpp | 18 ++- FHE/Subroutines.cpp | 38 ++--- FHEOffline/DataSetup.cpp | 8 +- FHEOffline/DataSetup.h | 2 +- FHEOffline/Multiplier.cpp | 2 +- FHEOffline/PairwiseMachine.cpp | 4 +- FHEOffline/PairwiseSetup.cpp | 5 +- FHEOffline/Proof.h | 1 + FHEOffline/Sacrificing.cpp | 9 +- FHEOffline/SimpleMachine.cpp | 7 +- FHEOffline/SimpleMachine.h | 1 - FHEOffline/TemiSetup.cpp | 4 +- GC/CcdShare.h | 2 +- GC/DealerPrep.h | 69 +++++++++ GC/FakeSecret.h | 2 +- GC/MaliciousCcdShare.h | 2 +- GC/MaliciousRepSecret.h | 4 +- GC/NoShare.h | 16 +- GC/Processor.hpp | 2 +- GC/SemiSecret.h | 106 +++++++++---- GC/{SemiSecret.cpp => SemiSecret.hpp} | 42 ++++-- GC/ShareSecret.h | 6 +- GC/ThreadMaster.hpp | 9 +- GC/TinyMC.h | 4 +- Machines/SPDZ2k.hpp | 1 + Machines/Semi.hpp | 1 + Machines/ShamirMachine.hpp | 1 - Machines/TripleMachine.cpp | 4 +- Machines/dealer-ring-party.cpp | 22 +++ Machines/emulate.cpp | 5 +- Machines/hemi-party.cpp | 1 + Machines/malicious-ccd-party.cpp | 3 +- Machines/semi-bin-party.cpp | 1 + Machines/soho-party.cpp | 1 + Machines/spdz2k-party.cpp | 8 +- Machines/temi-party.cpp | 1 + Machines/tinier-party.cpp | 4 +- Machines/tiny-party.cpp | 2 +- Makefile | 26 ++-- Math/Integer.h | 2 - Math/Setup.cpp | 15 +- Math/Setup.h | 1 + Math/bigint.h | 4 +- Math/gf2n.cpp | 35 ++++- Math/gf2n.h | 3 +- Networking/AllButLastPlayer.h | 67 +++++++++ Networking/Player.cpp | 32 +++- Networking/Player.h | 30 ++-- OT/NPartyTripleGenerator.hpp | 9 +- Processor/BaseMachine.cpp | 2 + Processor/Data_Files.hpp | 4 +- Processor/Input.h | 4 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 6 +- Processor/Instruction.hpp | 42 +++--- Processor/Machine.h | 19 ++- Processor/Machine.hpp | 135 ++++++++++++----- Processor/OfflineMachine.hpp | 7 +- Processor/Online-Thread.hpp | 5 +- Processor/OnlineMachine.h | 4 +- Processor/OnlineMachine.hpp | 44 +----- Processor/OnlineOptions.cpp | 49 ++++++- Processor/OnlineOptions.h | 9 +- Processor/OnlineOptions.hpp | 42 +++++- Processor/Program.h | 4 +- Processor/RingMachine.hpp | 16 +- Processor/RingOptions.cpp | 19 +-- Processor/RingOptions.h | 4 +- Processor/instructions.h | 4 +- Programs/Source/l2h_comparison.mpc | 3 + Programs/Source/l2h_multiplication.mpc | 1 + Protocols/Beaver.h | 1 + Protocols/ChaiGearPrep.hpp | 9 +- Protocols/CowGearOptions.cpp | 14 +- Protocols/CowGearOptions.h | 1 - Protocols/CowGearPrep.hpp | 7 +- Protocols/DabitSacrifice.h | 6 +- Protocols/DabitSacrifice.hpp | 6 + Protocols/DealerInput.h | 38 +++++ Protocols/DealerInput.hpp | 115 +++++++++++++++ Protocols/DealerMC.h | 42 ++++++ Protocols/DealerMC.hpp | 76 ++++++++++ Protocols/DealerPrep.h | 33 +++++ Protocols/DealerPrep.hpp | 196 +++++++++++++++++++++++++ Protocols/DealerShare.h | 76 ++++++++++ Protocols/FakeMC.h | 2 +- Protocols/FakeProtocol.h | 4 + Protocols/FakeShare.h | 3 +- Protocols/Hemi.h | 1 + Protocols/Hemi.hpp | 3 +- Protocols/HemiMatrixPrep.h | 11 +- Protocols/HemiMatrixPrep.hpp | 1 + Protocols/HemiPrep.hpp | 4 +- Protocols/MAC_Check.h | 3 +- Protocols/MAC_Check.hpp | 39 +---- Protocols/MAC_Check_Base.h | 3 +- Protocols/MAC_Check_Base.hpp | 10 +- Protocols/MalRepRingOptions.cpp | 4 +- Protocols/MalRepRingPrep.hpp | 2 +- Protocols/MaliciousRepMC.h | 6 +- Protocols/MaliciousRepMC.hpp | 4 +- Protocols/MaliciousRepPrep.hpp | 4 +- Protocols/MaliciousShamirMC.h | 2 +- Protocols/MaliciousShamirMC.hpp | 2 +- Protocols/MaliciousShamirShare.h | 3 + Protocols/MamaPrep.hpp | 12 +- Protocols/MamaShare.h | 5 + Protocols/NoLivePrep.h | 5 + Protocols/NoProtocol.h | 1 + Protocols/NoShare.h | 3 - Protocols/Rep3Share.h | 7 +- Protocols/Replicated.h | 2 + Protocols/Replicated.hpp | 13 +- Protocols/ReplicatedMC.h | 6 +- Protocols/ReplicatedMC.hpp | 2 +- Protocols/SPDZ.h | 3 +- Protocols/SPDZ2k.h | 28 ++++ Protocols/Semi2kShare.h | 2 - Protocols/SemiInput.h | 4 +- Protocols/SemiInput.hpp | 2 +- Protocols/SemiShare.h | 11 +- Protocols/ShamirMC.h | 2 +- Protocols/ShamirMC.hpp | 4 +- Protocols/ShamirShare.h | 37 +---- Protocols/Share.h | 10 +- Protocols/Share.hpp | 7 + Protocols/ShareInterface.h | 5 + Protocols/ShareMatrix.h | 2 +- Protocols/ShuffleSacrifice.hpp | 5 +- Protocols/SohoPrep.hpp | 1 + Protocols/Spdz2kPrep.h | 3 +- Protocols/Spdz2kShare.h | 3 +- Protocols/SpdzWise.hpp | 3 + Protocols/SpdzWiseMC.h | 4 +- Protocols/SpdzWiseShare.hpp | 14 +- Protocols/TemiPrep.h | 2 + Protocols/TemiPrep.hpp | 8 + Protocols/config.h | 13 ++ Protocols/fake-stuff.h | 2 +- Protocols/fake-stuff.hpp | 57 ++++++- Protocols/mac_key.hpp | 8 + README.md | 24 ++- Scripts/dealer-ring.sh | 10 ++ Scripts/memory-usage.py | 9 +- Scripts/test_tutorial.sh | 6 +- Scripts/tldr.sh | 6 +- Tools/Buffer.cpp | 5 + Tools/Exceptions.cpp | 5 + Tools/Exceptions.h | 6 + Tools/avx_memcpy.h | 9 +- Tools/benchmarking.cpp | 19 +++ Tools/benchmarking.h | 15 +- Tools/intrinsics.h | 2 + Tools/parse.h | 8 + Utils/Check-Offline.cpp | 12 +- Utils/Fake-Offline.cpp | 34 ++--- Utils/binary-example.cpp | 1 + Utils/l2h-example.cpp | 54 +++++++ Utils/mixed-example.cpp | 1 + Utils/paper-example.cpp | 5 +- Yao/YaoEvalWire.cpp | 2 +- Yao/YaoGarbleWire.cpp | 2 +- Yao/YaoPlayer.cpp | 2 +- Yao/YaoWire.hpp | 2 +- doc/Compiler.rst | 10 +- doc/non-linear.rst | 10 ++ 186 files changed, 2008 insertions(+), 618 deletions(-) create mode 100644 GC/DealerPrep.h rename GC/{SemiSecret.cpp => SemiSecret.hpp} (61%) create mode 100644 Machines/dealer-ring-party.cpp create mode 100644 Networking/AllButLastPlayer.h create mode 100644 Programs/Source/l2h_comparison.mpc create mode 100644 Programs/Source/l2h_multiplication.mpc create mode 100644 Protocols/DealerInput.h create mode 100644 Protocols/DealerInput.hpp create mode 100644 Protocols/DealerMC.h create mode 100644 Protocols/DealerMC.hpp create mode 100644 Protocols/DealerPrep.h create mode 100644 Protocols/DealerPrep.hpp create mode 100644 Protocols/DealerShare.h create mode 100644 Protocols/SPDZ2k.h create mode 100644 Protocols/config.h create mode 100755 Scripts/dealer-ring.sh create mode 100644 Utils/l2h-example.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a0406a8e..744d0ff1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.1 (Apr 19, 2022) + +- Protocol in dealer model +- Command-line option for security parameter +- Fixed security bug in SPDZ2k (see Section 3.4 of [the updated paper](https://eprint.iacr.org/2018/482)) +- Ability to run high-level (Python) code from C++ +- More memory capacity due to 64-bit addressing +- Homomorphic encryption for more fields of characteristic two +- Docker container + ## 0.3.0 (Feb 17, 2022) - Semi-honest computation based on threshold semi-homomorphic encryption diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index ef9c14a3f..e53b71879 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -305,7 +305,7 @@ class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, :param: memory address (int) """ code = opcodes['LDMSB'] - arg_format = ['sbw','int'] + arg_format = ['sbw','long'] class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): """ Copy secret bit register to secret bit memory cell with compile-time @@ -315,7 +315,7 @@ class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): :param: memory address (int) """ code = opcodes['STMSB'] - arg_format = ['sb','int'] + arg_format = ['sb','long'] # def __init__(self, *args, **kwargs): # super(type(self), self).__init__(*args, **kwargs) # import inspect @@ -330,7 +330,7 @@ class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, :param: memory address (int) """ code = opcodes['LDMCB'] - arg_format = ['cbw','int'] + arg_format = ['cbw','long'] class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): """ Copy clear bit register to clear bit memory cell with compile-time @@ -340,7 +340,7 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): :param: memory address (int) """ code = opcodes['STMCB'] - arg_format = ['cb','int'] + arg_format = ['cb','long'] class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ Copy secret bit memory cell with run-time address to secret bit diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 38c37a261..6c3abad0a 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -3,6 +3,9 @@ fixed-length types obtained by :py:obj:`get_type(n)` are the preferred way of using them, and in some cases required in connection with container types. + +Computation using these types will always be executed as a binary +circuit. See :ref:`protocol-pairs` for the exact protocols. """ from Compiler.types import MemValue, read_mem_value, regint, Array, cint @@ -17,7 +20,6 @@ from functools import reduce class bits(Tape.Register, _structure, _bit): - """ Base class for binary registers. """ n = 40 unit = 64 PreOp = staticmethod(floatingpoint.PreOpN) @@ -400,12 +402,18 @@ def get_random_bit(): res = sbit() inst.bitb(res) return res + @staticmethod + def _check_input_player(player): + if not util.is_constant(player): + raise CompilerError('player must be known at compile time ' + 'for binary circuit inputs') @classmethod def get_input_from(cls, player, n_bits=None): """ Secret input from :py:obj:`player`. :param: player (int) """ + cls._check_input_player(player) if n_bits is None: n_bits = cls.n res = cls() @@ -653,6 +661,7 @@ def get_input_from(cls, player): :param: player (int) """ + sbits._check_input_player(player) res = cls.from_vec(sbit() for i in range(n)) inst.inputbvec(n + 3, 0, player, *res.v) return res @@ -780,6 +789,8 @@ def coerce(self, other): size = other.size return (other.get_vector(base, min(64, size - base)) \ for base in range(0, size, 64)) + if not isinstance(other, type(self)): + return type(self)(other) return other def __xor__(self, other): other = self.coerce(other) @@ -1222,6 +1233,7 @@ def get_input_from(cls, player): :param: player (int) """ + sbits._check_input_player(player) v = cls.int_type() inst.inputb(player, cls.k, cls.f, v) return cls._new(v) @@ -1287,6 +1299,7 @@ def get_input_from(cls, player): :param: player (int) """ v = [sbit() for i in range(sbitfix.k)] + sbits._check_input_player(player) inst.inputbvec(len(v) + 3, sbitfix.f, player, *v) return cls._new(cls.int_type.from_vec(v)) def __init__(self, value=None, *args, **kwargs): diff --git a/Compiler/allocator.py b/Compiler/allocator.py index cf2f13ef4..bf431ca38 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -15,11 +15,11 @@ class BlockAllocator: """ Manages freed memory blocks. """ def __init__(self): - self.by_logsize = [defaultdict(set) for i in range(32)] + self.by_logsize = [defaultdict(set) for i in range(64)] self.by_address = {} def by_size(self, size): - if size >= 2 ** 32: + if size >= 2 ** 64: raise CompilerError('size exceeds addressing capability') return self.by_logsize[int(math.log(size, 2))][size] diff --git a/Compiler/instructions.py b/Compiler/instructions.py index e06797684..8a10ee58c 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -69,7 +69,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): """ __slots__ = [] code = base.opcodes['LDMC'] - arg_format = ['cw','int'] + arg_format = ['cw','long'] @base.gf2n @base.vectorize @@ -84,7 +84,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): """ __slots__ = [] code = base.opcodes['LDMS'] - arg_format = ['sw','int'] + arg_format = ['sw','long'] @base.gf2n @base.vectorize @@ -99,7 +99,7 @@ class stmc(base.DirectMemoryWriteInstruction): """ __slots__ = [] code = base.opcodes['STMC'] - arg_format = ['c','int'] + arg_format = ['c','long'] @base.gf2n @base.vectorize @@ -114,7 +114,7 @@ class stms(base.DirectMemoryWriteInstruction): """ __slots__ = [] code = base.opcodes['STMS'] - arg_format = ['s','int'] + arg_format = ['s','long'] @base.vectorize class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): @@ -128,7 +128,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): """ __slots__ = [] code = base.opcodes['LDMINT'] - arg_format = ['ciw','int'] + arg_format = ['ciw','long'] @base.vectorize class stmint(base.DirectMemoryWriteInstruction): @@ -142,7 +142,7 @@ class stmint(base.DirectMemoryWriteInstruction): """ __slots__ = [] code = base.opcodes['STMINT'] - arg_format = ['ci','int'] + arg_format = ['ci','long'] @base.vectorize class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d6c647add..8ae0b86fc 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -337,7 +337,7 @@ def reformat(arg_format): if isinstance(arg_format, list): __format = [] for __f in arg_format: - if __f in ('int', 'p', 'ci', 'str'): + if __f in ('int', 'long', 'p', 'ci', 'str'): __format.append(__f) else: __format.append(__f[0] + 'g' + __f[1:]) @@ -360,7 +360,7 @@ class GF2N_Instruction(instruction_cls): arg_format = instruction_cls.gf2n_arg_format elif isinstance(instruction_cls.arg_format, itertools.repeat): __f = next(instruction_cls.arg_format) - if __f != 'int' and __f != 'p': + if __f not in ('int', 'long', 'p'): arg_format = itertools.repeat(__f[0] + 'g' + __f[1:]) else: arg_format = copy.deepcopy(instruction_cls.arg_format) @@ -711,6 +711,14 @@ def __init__(self, f): def __str__(self): return str(self.i) +class LongArgFormat(IntArgFormat): + @classmethod + def encode(cls, arg): + return struct.pack('>Q', arg) + + def __init__(self, f): + self.i = struct.unpack('>Q', f.read(8))[0] + class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): @@ -768,6 +776,7 @@ def __str__(self): 'i': ImmediateModpAF, 'ig': ImmediateGF2NAF, 'int': IntArgFormat, + 'long': LongArgFormat, 'p': PlayerNoAF, 'str': String, } diff --git a/Compiler/library.py b/Compiler/library.py index 35d6f46f2..ef2fe1ab6 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1149,6 +1149,7 @@ def multithread(n_threads, n_items=None, max_size=None): :param n_threads: compile-time (int) :param n_items: regint/cint/int (default: :py:obj:`n_threads`) + :param max_size: maximum size to be processed at once (default: no limit) The following executes ``f(0, 8)``, ``f(8, 8)``, and ``f(16, 9)`` in three different threads: @@ -1366,6 +1367,18 @@ def _(base, size): left = (left + 1) // 2 return inputs[0] +def tree_reduce(function, sequence): + """ Round-efficient reduction. The following computes the maximum + of the list :py:obj:`l`:: + + m = tree_reduce(lambda x, y: x.max(y), l) + + :param function: reduction function taking two arguments + :param sequence: list, vector, or array + + """ + return util.tree_reduce(function, sequence) + def foreach_enumerate(a): """ Run-time loop over public data. This uses ``Player-Data/Public-Input/``. Example: diff --git a/Compiler/ml.py b/Compiler/ml.py index c521934fe..02f0f04ed 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1079,14 +1079,17 @@ def __repr__(self): (type(self).__name__, self.X.sizes, self.strides, self.ksize, self.padding) - def _forward(self, batch): + def forward(self, batch=None, training=False): + if batch is None: + batch = Array.create_from(regint(0)) def process(pool, bi, k, i, j): def m(a, b): c = a[0] > b[0] l = [c * x for x in a[1]] l += [(1 - c) * x for x in b[1]] return c.if_else(a[0], b[0]), l - red = util.tree_reduce(m, [(x[0], [1]) for x in pool]) + red = util.tree_reduce(m, [(x[0], [1] if training else []) + for x in pool]) self.Y[bi][i][j][k] = red[0] for i, x in enumerate(red[1]): self.comparisons[bi][k][i] = x diff --git a/Compiler/program.py b/Compiler/program.py index 366723304..78b802e14 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -400,6 +400,9 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) + if addr + size >= 2 ** 32: + raise CompilerError("allocation exceeded for type '%s'" % + mem_type) self.allocated_mem_blocks[addr,mem_type] = size if single_size: from .library import get_thread_number, runtime_error_if diff --git a/Compiler/types.py b/Compiler/types.py index b63597338..99ca6a8c6 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1710,7 +1710,12 @@ def reveal_to(self, player): res = Array.create_from(res) return personal(player, res) - def bit_decompose(self, length): + def bit_decompose(self, length=None): + """ Bit decomposition. + + :param length: number of bits + + """ return [personal(self.player, x) for x in self._v.bit_decompose(length)] def _san(self, other): @@ -2144,7 +2149,7 @@ class sint(_secret, _int): the bit length. :param val: initialization (sint/cint/regint/int/cgf2n or list - thereof or sbits/sbitvec/sfix) + thereof, sbits/sbitvec/sfix, or :py:class:`personal`) :param size: vector size (int), defaults to 1 or size of list When converting :py:class:`~Compiler.GC.types.sbits`, the result is a @@ -2152,6 +2157,9 @@ class sint(_secret, _int): :py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values with bit length equal the length of the input. + Initializing from a :py:class:`personal` value implies the + relevant party inputting their value securely. + """ __slots__ = [] instruction_type = 'modp' @@ -4285,6 +4293,7 @@ class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` for security considerations of the underlying integer operations. + The secret integer is stored as the :py:obj:`v` member. It supports basic arithmetic (``+, -, *, /``), returning :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), @@ -5121,7 +5130,8 @@ class Array(_vectorizable): array ``a`` and ``i`` being a :py:class:`regint`, :py:class:`cint`, or a Python integer. - :param length: compile-time integer (int) or :py:obj:`None` for unknown length + :param length: compile-time integer (int) or :py:obj:`None` + for unknown length (need to specify :py:obj:`address`) :param value_type: basic type :param address: if given (regint/int), the array will not be allocated @@ -5178,6 +5188,8 @@ def delete(self): self.address = None def get_address(self, index): + if isinstance(index, (_secret, _single)): + raise CompilerError('need cleartext index') key = str(index) if self.length is not None: from .GC.types import cbits @@ -5211,6 +5223,7 @@ def get_slice(self, index): if index.step == 0: raise CompilerError('slice step cannot be zero') return index.start or 0, \ + index.stop if self.length is None else \ min(index.stop or self.length, self.length), index.step or 1 def __getitem__(self, index): @@ -5517,7 +5530,15 @@ def print_reveal_nested(self, end='\n'): :param end: string to print after (default: line break) """ - library.print_str('%s' + end, self.get_vector().reveal()) + if util.is_constant(self.length): + library.print_str('%s' + end, self.get_vector().reveal()) + else: + library.print_str('[') + @library.for_range(self.length - 1) + def _(i): + library.print_str('%s, ', self[i].reveal()) + library.print_str('%s', self[self.length - 1].reveal()) + library.print_str(']' + end) def reveal_to_binary_output(self, player=None): """ Reveal to binary output if supported by type. @@ -5893,7 +5914,8 @@ def dot(self, other, res_params=None, n_threads=None): """ Matrix-matrix and matrix-vector multiplication. :param self: two-dimensional - :param other: Matrix or Array of matching size and type """ + :param other: Matrix or Array of matching size and type + :param n_threads: number of threads (default: all in same thread) """ assert len(self.sizes) == 2 if isinstance(other, Array): assert len(other) == self.sizes[1] @@ -5928,6 +5950,7 @@ def _(base, size): res_matrix.assign_part_vector( self.get_part(base, size).direct_mul(other), base) except AttributeError: + assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() A = self.get_vector() @@ -5937,7 +5960,7 @@ def _(base, size): res_params)) except (AttributeError, AssertionError): # fallback for sfloat etc. - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt_multithread(n_threads, self.sizes[0]) def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index e426bade9..4657b5d88 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -28,6 +28,7 @@ class P256Element : public ValueInterface static const true_type invertible; static int size() { return 0; } + static int length() { return 256; } static string type_string() { return "P256"; } static void init(); diff --git a/ECDSA/semi-ecdsa-party.cpp b/ECDSA/semi-ecdsa-party.cpp index 6bdcec286..d7a4d8836 100644 --- a/ECDSA/semi-ecdsa-party.cpp +++ b/ECDSA/semi-ecdsa-party.cpp @@ -10,6 +10,7 @@ #include "Protocols/SemiPrep.hpp" #include "Protocols/SemiInput.hpp" #include "Protocols/MAC_Check_Base.hpp" +#include "GC/SemiSecret.hpp" #include "ot-ecdsa-party.hpp" #include diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 5f8e76fd3..de9e9cad4 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -49,8 +49,8 @@ class Client * @param n number of values * @returns vector of integer-like values */ - template - vector receive_outputs(int n); + template + vector receive_outputs(int n); }; #endif /* EXTERNALIO_CLIENT_H_ */ diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index 601d9a486..3af40f2f4 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -91,8 +91,8 @@ void Client::send_private_inputs(const vector& values) // Receive shares of the result and sum together. // Also receive authenticating values. -template -vector Client::receive_outputs(int n) +template +vector Client::receive_outputs(int n) { vector triples(3 * n); octetStream os; @@ -111,7 +111,7 @@ vector Client::receive_outputs(int n) } } - vector output_values; + vector output_values; for (int i = 0; i < 3 * n; i += 3) { if (T(triples[i] * triples[i + 1]) != triples[i + 2]) diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index f68384c00..b040dd5e8 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -46,7 +46,7 @@ #include #include -template +template void one_run(T salary_value, Client& client) { // Run the computation @@ -54,18 +54,18 @@ void one_run(T salary_value, Client& client) cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; // Get the result back (client_id of winning client) - T result = client.receive_outputs(1)[0]; + U result = client.receive_outputs(1)[0]; cout << "Winning client id is : " << result << endl; } -template +template void run(double salary_value, Client& client) { // sint - one_run(long(round(salary_value)), client); + one_run(long(round(salary_value)), client); // sfix with f = 16 - one_run(long(round(salary_value * exp2(16))), client); + one_run(long(round(salary_value * exp2(16))), client); } int main(int argc, char** argv) @@ -125,7 +125,7 @@ int main(int argc, char** argv) { gfp::init_field(specification.get()); cerr << "using prime " << gfp::pr() << endl; - run(salary_value, client); + run(salary_value, client); break; } case 'R': @@ -134,13 +134,13 @@ int main(int argc, char** argv) switch (R) { case 64: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; case 104: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; case 128: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; default: cerr << R << "-bit ring not implemented"; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 0de8bb1e9..5fb07f233 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -2,9 +2,11 @@ #include "FHE_Params.h" #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" +#include "Protocols/HemiOptions.h" +#include "Processor/OnlineOptions.h" -FHE_Params::FHE_Params(int n_mults) : - FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1) +FHE_Params::FHE_Params(int n_mults, int drown_sec) : + FFTData(n_mults + 1), Chi(0.7), sec_p(drown_sec), matrix_dim(1) { } @@ -17,16 +19,20 @@ void FHE_Params::set(const Ring& R, for (size_t i = 0; i < FFTData.size(); i++) FFTData[i].init(R,primes[i]); - set_sec(40); + set_sec(sec_p); } void FHE_Params::set_sec(int sec) { + assert(sec >= 0); sec_p=sec; Bval=1; Bval=Bval<matrix_dim = matrix_dim; } +void FHE_Params::set_matrix_dim_from_options() +{ + set_matrix_dim( + HemiOptions::singleton.plain_matmul ? + 1 : OnlineOptions::singleton.batch_size); +} + bigint FHE_Params::Q() const { bigint res = FFTData[0].get_prime(); diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 9407b0ba4..8821e2e29 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -13,6 +13,7 @@ #include "FHE/FFT_Data.h" #include "FHE/DiscreteGauss.h" #include "Tools/random.h" +#include "Protocols/config.h" class FHE_Params { @@ -30,15 +31,17 @@ class FHE_Params public: - FHE_Params(int n_mults = 1); + FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY); int n_mults() const { return FFTData.size() - 1; } void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); + void set_min_sec(int sec); void set_matrix_dim(int matrix_dim); + void set_matrix_dim_from_options(); int get_matrix_dim() const { return matrix_dim; } const vector& FFTD() const { return FFTData; } diff --git a/FHE/Matrix.cpp b/FHE/Matrix.cpp index c9c23aaba..dcec137e4 100644 --- a/FHE/Matrix.cpp +++ b/FHE/Matrix.cpp @@ -68,6 +68,10 @@ void HNF(matrix& H,matrix& U,const matrix& A) { int m=A.size(),n=A[0].size(),r,i,j,k; +#ifdef VERBOSE + cerr << "HNF m=" << m << ", n=" << n << endl; +#endif + H=A; ident(U,n); r=min(m,n); @@ -79,9 +83,9 @@ void HNF(matrix& H,matrix& U,const matrix& A) { if (step==2) { // Step 2 k=-1; - mn=bigint(1)<<256; + mn=bigint(0); for (j=i; j 0) @@ -428,6 +429,16 @@ GF2X Subs_PowX_Mod(const GF2X& a,int pow,int m,const GF2X& c) +GF2X get_F(const Ring& Rg) +{ + GF2X F; + for (int i=0; i<=Rg.phi_m(); i++) + { if (((Rg.Phi()[i])%2)!=0) + { SetCoeff(F,i,1); } + } + //cout << "F = " << F << endl; + return F; +} void init(P2Data& P2D,const Ring& Rg) { @@ -438,16 +449,12 @@ void init(P2Data& P2D,const Ring& Rg) { SetCoeff(G,gf2n_short::get_t(i),1); } //cout << "G = " << G << endl; - for (int i=0; i<=Rg.phi_m(); i++) - { if (((Rg.Phi()[i])%2)!=0) - { SetCoeff(F,i,1); } - } - //cout << "F = " << F << endl; + F = get_F(Rg); // seed randomness to achieve same result for all players // randomness is used in SFCanZass and FindRoot SetSeed(ZZ(0)); - + // Now factor F modulo 2 vec_GF2X facts=SFCanZass(F); @@ -459,17 +466,34 @@ void init(P2Data& P2D,const Ring& Rg) // Compute the quotient group QGroup QGrp; int Gord=-1,e=Rg.phi_m()/d; // e = # of plaintext slots, phi(m)/degree - int seed=1; - while (Gord!=e) + + if ((e*gf2n_short::degree())!=Rg.phi_m()) + { cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl; + cout << e << " * " << gf2n_short::degree() << " != " << Rg.phi_m() << endl; + throw invalid_params(); + } + + int max_tries = 10; + for (int seed = 0;; seed++) { QGrp.assign(Rg.m(),seed); // QGrp encodes the the quotient group Z_m^*/<2> - Gord=QGrp.order(); - if (Gord!=e) { cout << "Group order wrong, need to repeat the Haf-Mc algorithm" << endl; seed++; } + Gord = QGrp.order(); + if (Gord == e) + { + break; + } + else + { + if (seed == max_tries) + { + cerr << "abort after " << max_tries << " tries" << endl; + throw invalid_params(); + } + else + cout << "Group order wrong, need to repeat the Haf-Mc algorithm" + << endl; + } } //cout << " l = " << Gord << " , d = " << d << endl; - if ((Gord*gf2n_short::degree())!=Rg.phi_m()) - { cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl; - throw not_implemented(); - } vector Fi(Gord); vector Rts(Gord); @@ -590,8 +614,23 @@ void char_2_dimension(int& m, int& lg2) m=5797; lg2=40; break; + case 64: + m = 9615; + break; + case 63: + m = 9271; + break; + case 28: + m = 3277; + break; case 16: - m = 13107; + m = 4369; + break; + case 12: + m = 4095; + break; + case 11: + m = 2047; break; default: throw runtime_error("field size not supported"); @@ -628,7 +667,7 @@ void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D) finalize_lengths(lg2p0, lg2p1, n, m, lg2pi[0], round_up, params); } - if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) > phi_N(m)) + if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) * 2 > m) throw runtime_error("number of slots too small"); cout << "m = " << m << endl; diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index acaba70b6..7d7d13de6 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -55,12 +55,19 @@ int generate_semi_setup(int plaintext_length, int sec, FHE_Params& params, FD& FieldD, bool round_up, int n = 1); // field-independent semi-homomorphic setup -int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, +int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up); void init(Ring& Rg, int m, bool generate_poly); void init(P2Data& P2D,const Ring& Rg); +namespace NTL +{ +class GF2X; +} + +NTL::GF2X get_F(const Ring& Rg); + // For use when we want p to be a specific value void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, bigint& pr1, int n, int sec, bigint& p, FHE_Params& params); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index f2e151c42..a1fe3e033 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -36,11 +36,12 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s); // unify parameters by taking maximum over TopGear or not bigint B_clean_top_gear = B_clean * 2; - bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); + bigint B_clean_not_top_gear = B_clean << max(slack - sec, 0); B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); int matrix_dim = params.get_matrix_dim(); #ifdef NOISY + cout << "phi(m): " << phi_m << endl; cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "V_s: " << V_s << endl; cout << "c1: " << c1 << endl; @@ -50,10 +51,13 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, cout << "B_clean: " << B_clean << endl; cout << "B_scale: " << B_scale << endl; cout << "matrix dimension: " << matrix_dim << endl; + cout << "drown sec: " << params.secp() << endl; + cout << "sec: " << sec << endl; #endif assert(matrix_dim > 0); - drown = 1 + matrix_dim * n * (bigint(1) << sec); + assert(params.secp() >= 0); + drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp()); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) @@ -71,8 +75,14 @@ double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma) { if (sigma <= 0) sigma = FHE_Params().get_R(); - // the constant was updated using Martin Albrecht's LWE estimator in Sep 2019 - return 37.8 * (log_q - log2(sigma)); + // the constant was updated using Martin Albrecht's LWE estimator in Mar 2022 + // found the following pairs for 128-bit security + // and alpha = 0.7 * sqrt(2*pi) / q + // m = 2048, log_2(q) = 68 + // m = 4096, log_2(q) = 138 + // m = 8192, log_2(q) = 302 + // m = 16384, log_2(q) = 560 + return 15.1 * log_q; } double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, const FHE_Params& params) diff --git a/FHE/Subroutines.cpp b/FHE/Subroutines.cpp index 8ec6a7cee..a688b1691 100644 --- a/FHE/Subroutines.cpp +++ b/FHE/Subroutines.cpp @@ -11,35 +11,15 @@ void Subs(modp& ans,const vector& poly,const modp& x,const Zp_Data& ZpD) assignZero(ans,ZpD); for (int i=poly.size()-1; i>=0; i--) { Mul(ans,ans,x,ZpD); - switch (poly[i]) - { case 0: - break; - case 1: - Add(ans,ans,one,ZpD); - break; - case -1: - Sub(ans,ans,one,ZpD); - break; - case 2: - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - break; - case -2: - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - break; - case 3: - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - break; - case -3: - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - break; - default: - throw not_implemented(); + if (poly[i] > 0) + { + for (int j = 0; j < poly[i]; j++) + Add(ans, ans, one, ZpD); + } + if (poly[i] < 0) + { + for (int j = 0; j < -poly[i]; j++) + Sub(ans, ans, one, ZpD); } } } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 48a8a6ef8..0dc8d9675 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -40,10 +40,9 @@ template void PartSetup::generate_setup(int n_parties, int plaintext_length, int sec, int slack, bool round_up) { - sec = max(sec, 40); + params.set_min_sec(sec); Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup( params, FieldD); - params.set_sec(sec); pk = FHE_PK(params, FieldD.get_prime()); sk = FHE_SK(params, FieldD.get_prime()); calpha = Ciphertext(params); @@ -180,11 +179,8 @@ void PartSetup::init_field() } template -void PartSetup::check(int sec) const +void PartSetup::check() const { - sec = max(sec, 40); - if (abs(sec - params.secp()) > 2) - throw runtime_error("security parameters vary too much between protocol and distributed decryption"); sk.check(params, pk, FieldD.get_prime()); } diff --git a/FHEOffline/DataSetup.h b/FHEOffline/DataSetup.h index 88c9e05fc..1c606506d 100644 --- a/FHEOffline/DataSetup.h +++ b/FHEOffline/DataSetup.h @@ -57,7 +57,7 @@ class PartSetup void init_field(); - void check(int sec) const; + void check() const; bool operator!=(const PartSetup& other); void secure_init(Player& P, MachineBase& machine, int plaintext_length, diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 92632002c..43ad7e842 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -69,7 +69,7 @@ void Multiplier::add(Plaintext_& res, const Ciphertext& c, product_share.randomize(G); bigint B = 6 * machine.setup().params.get_R(); B *= machine.setup().FieldD.get_prime(); - B <<= machine.drown_sec; + B <<= machine.setup().params.secp(); // slack B *= NonInteractiveProof::slack(machine.sec, machine.setup().params.phi_m()); diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index e41fe1837..b19dd62cf 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -29,7 +29,7 @@ void PairwiseMachine::init() { if (use_gf2n) { - field_size = 40; + field_size = gf2n_short::DEFAULT_LENGTH; gf2n_short::init_field(field_size); setup_keys(); } @@ -67,7 +67,7 @@ void PairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); - s.init(P, drown_sec, field_size, extra_slack); + s.init(P, sec, field_size, extra_slack); if (output) write_mac_key(get_prep_dir(P), P.my_num(), P.num_players(), s.alphai); for (auto& x : other_pks) diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 047c84f2c..59223ad03 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -65,11 +65,12 @@ template void secure_init(T& setup, Player& P, U& machine, int plaintext_length, int sec, FHE_Params& params) { + assert(sec >= 0); machine.sec = sec; - sec = max(sec, 40); - machine.drown_sec = sec; + params.set_min_sec(sec); string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(params.secp()) + "-" + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index 5e690b67c..6059ef3bc 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -78,6 +78,7 @@ class Proof diagonal(diagonal), B_plain_length(0), B_rand_length(0), pk(&pk), n_proofs(n_proofs) { sec=sc; + assert(sec > 0); tau=Tau; rho=Rho; phim=(pk.get_params()).phi_m(); diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index dab3e507a..63d559d53 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -10,6 +10,8 @@ #include "Tools/Subroutines.h" +#include "Protocols/mac_key.hpp" + // The number of sacrifices to amortize at one time #define amortize 512 @@ -19,12 +21,7 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, int output_thread, TripleSacriFactory< Share >& factory, bool write_output, bool clear, string dir) { - if (T::length() < 40) - { - cerr << "Field too small for reasonable security" << endl; - cerr << "Use a larger field or remove this warning from " << __FILE__ << endl; - exit(1); - } + check_field_size(); ofstream outf; if (write_output) diff --git a/FHEOffline/SimpleMachine.cpp b/FHEOffline/SimpleMachine.cpp index 04d190f1c..a7be00a42 100644 --- a/FHEOffline/SimpleMachine.cpp +++ b/FHEOffline/SimpleMachine.cpp @@ -27,7 +27,7 @@ void* run_generator(void* generator) MachineBase::MachineBase() : throughput_loop_thread(0),portnum_base(0), data_type(DATA_TRIPLE), - sec(0), drown_sec(0), field_size(0), extra_slack(0), + sec(0), field_size(0), extra_slack(0), produce_inputs(false), use_gf2n(false) { @@ -91,7 +91,6 @@ void MachineBase::parse_options(int argc, const char** argv) opt.get("-h")->getString(hostname); opt.get("-pn")->getInt(portnum_base); opt.get("-s")->getInt(sec); - drown_sec = max(40, sec); opt.get("-f")->getInt(field_size); use_gf2n = opt.isSet("-2"); if (use_gf2n) @@ -221,7 +220,7 @@ void MultiplicativeMachine::fake_keys(int slack) PartSetup& part_setup = setup.part(); if (P.my_num() == 0) { - part_setup.generate_setup(N.num_players(), field_size, drown_sec, slack, true); + part_setup.generate_setup(N.num_players(), field_size, sec, slack, true); vector > setups; part_setup.fake(setups, P.num_players(), false); for (int i = 1; i < P.num_players(); i++) @@ -238,7 +237,7 @@ void MultiplicativeMachine::fake_keys(int slack) P.receive_player(0, os); } part_setup.unpack(os); - part_setup.check(drown_sec); + part_setup.check(); part_setup.alphai = read_or_generate_mac_key>(P); Plaintext_ m(part_setup.FieldD); diff --git a/FHEOffline/SimpleMachine.h b/FHEOffline/SimpleMachine.h index 8ca37bbfe..e8d071700 100644 --- a/FHEOffline/SimpleMachine.h +++ b/FHEOffline/SimpleMachine.h @@ -26,7 +26,6 @@ class MachineBase : public OfflineMachineBase public: int sec; - int drown_sec; int field_size; int extra_slack; bool produce_inputs; diff --git a/FHEOffline/TemiSetup.cpp b/FHEOffline/TemiSetup.cpp index fc222ed51..fd922d4cf 100644 --- a/FHEOffline/TemiSetup.cpp +++ b/FHEOffline/TemiSetup.cpp @@ -15,9 +15,7 @@ TemiSetup::TemiSetup() this->pk = {this->params, 0}; this->sk = {this->params, 0}; this->calpha = this->params; - this->params.set_matrix_dim( - HemiOptions::singleton.plain_matmul ? - 1 : OnlineOptions::singleton.batch_size); + this->params.set_matrix_dim_from_options(); } template diff --git a/GC/CcdShare.h b/GC/CcdShare.h index e890ce633..894d3ae74 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -40,7 +40,7 @@ class CcdShare : public ShamirShare, public ShareSecret> return "CCD"; } - static MAC_Check* new_mc(T) + static MAC_Check* new_mc(typename super::mac_key_type) { return new MAC_Check; } diff --git a/GC/DealerPrep.h b/GC/DealerPrep.h new file mode 100644 index 000000000..a3bd4bcc8 --- /dev/null +++ b/GC/DealerPrep.h @@ -0,0 +1,69 @@ +/* + * DealerPrep.h + * + */ + +#ifndef GC_DEALERPREP_H_ +#define GC_DEALERPREP_H_ + +#include "Protocols/DealerPrep.h" +#include "Protocols/ProtocolSet.h" +#include "ShiftableTripleBuffer.h" +#include "SemiSecret.h" + +namespace GC +{ +class DealerPrep : public BufferPrep, ShiftableTripleBuffer +{ + Player* P; + +public: + DealerPrep(DataPositions& usage, int = -1) : + BufferPrep(usage), P(0) + { + } + + void set_protocol(DealerSecret::Protocol& protocol) + { + P = &protocol.P; + } + + void buffer_triples() + { + ProtocolSetup> setup(*P); + ProtocolSet> set(*P, setup); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + auto triple = set.preprocessing.get_triple( + DealerSecret::default_length); + this->triples.push_back({{triple[0], triple[1], triple[2]}}); + } + } + + void buffer_bits() + { + SeededPRNG G; + if (P->my_num() != 0) + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->bits.push_back(G.get_bit()); + else + this->bits.resize( + this->bits.size() + OnlineOptions::singleton.batch_size); + } + + void get(Dtype type, DealerSecret* data) + { + BufferPrep::get(type, data); + } + + array get_triple_no_count(int n_bits) + { + if (n_bits == -1) + n_bits = DealerSecret::default_length; + return ShiftableTripleBuffer::get_triple_no_count(n_bits); + } +}; + +} + +#endif /* GC_DEALERPREP_H_ */ diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 00e6c52c9..ee7a84462 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -10,6 +10,7 @@ #include "GC/Memory.h" #include "GC/Access.h" #include "GC/ArgTuples.h" +#include "GC/NoShare.h" #include "Math/gf2nlong.h" #include "Tools/SwitchableOutput.h" @@ -40,7 +41,6 @@ class FakeSecret : public ShareInterface, public BitVec typedef FakeSecret DynamicType; typedef Memory DynamicMemory; - typedef BitVec mac_key_type; typedef BitVec clear; typedef BitVec open_type; diff --git a/GC/MaliciousCcdShare.h b/GC/MaliciousCcdShare.h index 9dc63fc63..fbc66ea1a 100644 --- a/GC/MaliciousCcdShare.h +++ b/GC/MaliciousCcdShare.h @@ -44,7 +44,7 @@ class MaliciousCcdShare: public MaliciousShamirShare, public ShareSecret< return "Malicious CCD"; } - static MAC_Check* new_mc(T) + static MAC_Check* new_mc(typename super::mac_key_type) { return new MAC_Check; } diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h index 500bbb5af..9f941d51d 100644 --- a/GC/MaliciousRepSecret.h +++ b/GC/MaliciousRepSecret.h @@ -30,7 +30,7 @@ class SmallMalRepSecret : public FixedVec, 2> typedef MaliciousRepMC MC; typedef BitVec_ open_type; typedef open_type clear; - typedef BitVec mac_key_type; + typedef NoValue mac_key_type; static MC* new_mc(mac_key_type) { @@ -71,7 +71,7 @@ class MalRepSecretBase : public ReplicatedSecret static const bool expensive_triples = true; - static MC* new_mc(BitVec) + static MC* new_mc(typename super::mac_key_type) { try { diff --git a/GC/NoShare.h b/GC/NoShare.h index c435ec3f8..49f93ac42 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -60,35 +60,43 @@ class NoValue : public ValueInterface throw not_implemented(); } + static void init_minimum(int) + { + } + static void fail() { throw runtime_error("VM does not support binary circuits"); } NoValue() {} - NoValue(int) { fail(); } + NoValue(bool) {} + NoValue(ValueInterface) {} + NoValue(int128) {} void assign(const char*) { fail(); } + const char* get_ptr() const { return (char*) this; } + int get() const { fail(); return 0; } int operator<<(int) const { fail(); return 0; } void operator+=(int) { fail(); } - bool operator!=(NoValue) const { fail(); return 0; } + bool operator!=(NoValue) const { return false; } bool operator==(int) { fail(); return false; } bool get_bit(int) { fail(); return 0; } - void randomize(PRNG&) { fail(); } + void randomize(PRNG&) {} void invert() { fail(); } void mask(int) { fail(); } void input(istream&, bool) { fail(); } - void output(ostream&, bool) { fail(); } + void output(ostream&, bool) {} }; inline ostream& operator<<(ostream& o, NoValue) diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 663d55fcb..96b2d62d8 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -340,7 +340,7 @@ void Processor::convcbit2s(const BaseInstruction& instruction) for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) S[instruction.get_r(0) + i] = T::constant(C[instruction.get_r(1) + i], share_thread.P->my_num(), share_thread.MC->get_alphai(), - min(unsigned(unit), instruction.get_n() - i * unit)); + min(size_t(unit), instruction.get_n() - i * unit)); } template diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index ae10b5222..e95554bf5 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -8,6 +8,7 @@ #include "Protocols/SemiMC.h" #include "Protocols/SemiShare.h" +#include "Protocols/DealerShare.h" #include "Processor/DummyProtocol.h" #include "ShareSecret.h" @@ -17,71 +18,116 @@ namespace GC { class SemiPrep; +class DealerPrep; -class SemiSecret : public SemiShare, public ShareSecret +template +class SemiSecretBase : public V, public ShareSecret { + typedef V super; + public: - typedef Memory DynamicMemory; + typedef Memory DynamicMemory; - typedef SemiMC MC; - typedef DirectSemiMC Direct_MC; - typedef Beaver Protocol; - typedef MC MAC_Check; - typedef SemiPrep LivePrep; - typedef SemiInput Input; + typedef Beaver Protocol; - typedef SemiSecret part_type; - typedef SemiSecret small_type; + typedef T part_type; + typedef T small_type; static const int default_length = sizeof(BitVec) * 8; static string type_string() { return "binary secret"; } static string phase_name() { return "Binary computation"; } - static MC* new_mc(mac_key_type); - - template - static void generate_mac_key(mac_key_type, T) - { - } - - static void trans(Processor& processor, int n_outputs, + static void trans(Processor& processor, int n_outputs, const vector& args); - SemiSecret() + SemiSecretBase() { } - SemiSecret(long other) : - SemiShare(other) + SemiSecretBase(long other) : + V(other) { } - SemiSecret(const IntBase& other) : - SemiShare(other) + template + SemiSecretBase(const IntBase& other) : + V(other) { } template - SemiSecret(const Z2& other) : - SemiShare(other) + SemiSecretBase(const Z2& other) : + V(other) { } void load_clear(int n, const Integer& x); - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; + void bitcom(Memory& S, const vector& regs); + void bitdec(Memory& S, const vector& regs) const; - void xor_(int n, const SemiSecret& x, const SemiSecret& y) + void xor_(int n, const T& x, const T& y) { *this = BitVec(x ^ y).mask(n); } - void xor_bit(int i, const SemiSecret& bit) + void xor_bit(int i, const T& bit) { *this ^= bit << i; } void reveal(size_t n_bits, Clear& x); - SemiSecret lsb() + T lsb() { return *this & 1; } }; +class SemiSecret: public SemiSecretBase> +{ + typedef SemiSecret This; + +public: + typedef SemiSecretBase> super; + + typedef SemiMC MC; + typedef DirectSemiMC Direct_MC; + typedef MC MAC_Check; + typedef SemiInput Input; + typedef SemiPrep LivePrep; + + static MC* new_mc(typename SemiShare::mac_key_type); + + SemiSecret() + { + } + + template + SemiSecret(const T& other) : + super(other) + { + } +}; + +class DealerSecret : public SemiSecretBase> +{ + typedef DealerSecret This; + +public: + typedef SemiSecretBase> super; + + typedef DealerMC MC; + typedef DirectDealerMC Direct_MC; + typedef MC MAC_Check; + typedef DealerInput Input; + typedef DealerPrep LivePrep; + + static MC* new_mc(typename super::mac_key_type); + + DealerSecret() + { + } + + template + DealerSecret(const T& other) : + super(other) + { + } +}; + } /* namespace GC */ #endif /* GC_SEMISECRET_H_ */ diff --git a/GC/SemiSecret.cpp b/GC/SemiSecret.hpp similarity index 61% rename from GC/SemiSecret.cpp rename to GC/SemiSecret.hpp index 704e2a2fb..f6a4d3984 100644 --- a/GC/SemiSecret.cpp +++ b/GC/SemiSecret.hpp @@ -4,17 +4,30 @@ */ #include "GC/ShareParty.h" -#include "SemiSecret.h" - #include "GC/ShareSecret.hpp" #include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/DealerMC.h" +#include "SemiSecret.h" namespace GC { -const int SemiSecret::default_length; +template +const int SemiSecretBase::default_length; + +inline +SemiSecret::MC* SemiSecret::new_mc( + typename super::mac_key_type) +{ + if (OnlineOptions::singleton.direct) + return new Direct_MC; + else + return new MC; +} -SemiSecret::MC* SemiSecret::new_mc(mac_key_type) +inline +DealerSecret::MC* DealerSecret::new_mc( + typename super::mac_key_type) { if (OnlineOptions::singleton.direct) return new Direct_MC; @@ -22,7 +35,8 @@ SemiSecret::MC* SemiSecret::new_mc(mac_key_type) return new MC; } -void SemiSecret::trans(Processor& processor, int n_outputs, +template +void SemiSecretBase::trans(Processor& processor, int n_outputs, const vector& args) { int N_BITS = default_length; @@ -46,29 +60,33 @@ void SemiSecret::trans(Processor& processor, int n_outputs, } } -void SemiSecret::load_clear(int n, const Integer& x) +template +void SemiSecretBase::load_clear(int n, const Integer& x) { - check_length(n, x); - *this = constant(x, ShareThread::s().P->my_num()); + this->check_length(n, x); + *this = this->constant(x, ShareThread::s().P->my_num()); } -void SemiSecret::bitcom(Memory& S, const vector& regs) +template +void SemiSecretBase::bitcom(Memory& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) *this ^= (S[regs[i]] << i); } -void SemiSecret::bitdec(Memory& S, +template +void SemiSecretBase::bitdec(Memory& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) S[regs[i]] = (*this >> i) & 1; } -void SemiSecret::reveal(size_t n_bits, Clear& x) +template +void SemiSecretBase::reveal(size_t n_bits, Clear& x) { - auto& thread = ShareThread::s(); + auto& thread = ShareThread::s(); x = thread.MC->POpen(*this, *thread.P).mask(n_bits); } diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 6d9f26525..fb254486b 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -112,8 +112,8 @@ class RepSecretBase : public FixedVec, public ShareSecret typedef BitVec clear; typedef BitVec open_type; - typedef BitVec mac_type; - typedef BitVec mac_key_type; + typedef NoShare mac_type; + typedef NoValue mac_key_type; typedef NoShare bit_type; @@ -213,7 +213,7 @@ class SmallRepSecret : public FixedVec, 2> typedef ReplicatedMC MC; typedef BitVec_ open_type; typedef open_type clear; - typedef BitVec mac_key_type; + typedef NoValue mac_key_type; static MC* new_mc(mac_key_type) { diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index c6c9dcaac..a426eea29 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -11,6 +11,8 @@ #include "instructions.h" +#include "Tools/benchmarking.h" + #include "Machine.hpp" namespace GC @@ -58,15 +60,10 @@ Thread* ThreadMaster::new_thread(int i) template void ThreadMaster::run() { -#ifndef INSECURE if (not opts.live_prep) { - cerr - << "Preprocessing from file not supported by binary virtual machines" - << endl; - exit(1); + insecure("preprocessing from file in binary virtual machines"); } -#endif P = new PlainPlayer(N, "main"); diff --git a/GC/TinyMC.h b/GC/TinyMC.h index e0a0b948b..ac3a29ab2 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -48,12 +48,12 @@ class TinyMC : public MAC_Check_Base part_MC.exchange(P); } - typename T::open_type finalize_open() + typename T::open_type finalize_raw() { int n = sizes.next(); typename T::open_type opened = 0; for (int i = 0; i < n; i++) - opened += typename T::open_type(part_MC.finalize_open().get_bit(0)) << i; + opened += typename T::open_type(part_MC.finalize_raw().get_bit(0)) << i; return opened; } diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 6cb02779d..04508c09a 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -5,6 +5,7 @@ #include "Protocols/Spdz2kShare.h" #include "Protocols/Spdz2kPrep.h" +#include "Protocols/SPDZ2k.h" #include "GC/TinySecret.h" #include "GC/TinyMC.h" diff --git a/Machines/Semi.hpp b/Machines/Semi.hpp index 1a0931467..02686df79 100644 --- a/Machines/Semi.hpp +++ b/Machines/Semi.hpp @@ -19,3 +19,4 @@ #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "GC/SemiSecret.hpp" diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 9f18d3a6f..d8cc30140 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -97,6 +97,5 @@ ShamirMachineSpec::ShamirMachineSpec(int argc, const char** argv) auto& opts = ShamirOptions::singleton; ez::ezOptionParser opt; opts = {opt, argc, argv}; - T::bit_type::part_type::open_type::init_field(); HonestMajorityFieldMachine(argc, argv, opt, opts.nparties); } diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index 82cde5e8d..e4a876fe2 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -142,8 +142,8 @@ TripleMachine::TripleMachine(int argc, const char** argv) : gfpvar1::init_field(prime, false); else gfpvar1::init_default(128, false); - gf2n_long::init_field(128); - gf2n_short::init_field(40); + gf2n_long::init_field(); + gf2n_short::init_field(); PRNG G; G.ReSeed(); diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp new file mode 100644 index 000000000..4bc8fab1a --- /dev/null +++ b/Machines/dealer-ring-party.cpp @@ -0,0 +1,22 @@ +/* + * dealer-ring-party.cpp + * + */ + +#include "Protocols/DealerShare.h" +#include "Protocols/DealerInput.h" + +#include "Processor/RingMachine.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/DealerPrep.hpp" +#include "Protocols/DealerInput.hpp" +#include "Protocols/DealerMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Semi.hpp" +#include "GC/DealerPrep.h" + +int main(int argc, const char** argv) +{ + HonestMajorityRingMachine(argc, argv, 0); +} diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index f26f5f324..5999050c2 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -54,9 +54,8 @@ int main(int argc, const char** argv) { #define X(L) \ case L: \ - Machine>, FakeShare>(0, N, progname, \ - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, false, \ - online_opts.live_prep, online_opts).run(); \ + Machine>, FakeShare>(N, false, online_opts, \ + gf2n::default_degree()).run(progname); \ break; X(64) X(128) X(256) X(192) X(384) X(512) #ifdef RING_SIZE diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp index 934c15dcd..60e0d6e4f 100644 --- a/Machines/hemi-party.cpp +++ b/Machines/hemi-party.cpp @@ -27,6 +27,7 @@ #include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" +#include "GC/SemiSecret.hpp" #include "Math/gfp.hpp" int main(int argc, const char** argv) diff --git a/Machines/malicious-ccd-party.cpp b/Machines/malicious-ccd-party.cpp index 4ce84aea3..55ec8b99b 100644 --- a/Machines/malicious-ccd-party.cpp +++ b/Machines/malicious-ccd-party.cpp @@ -18,8 +18,9 @@ int main(int argc, const char** argv) { - gf2n_short::init_field(40); ez::ezOptionParser opt; ShamirOptions::singleton = {opt, argc, argv}; + OnlineOptions opts(opt, argc, argv); + gf2n_short::init_minimum(opts.security_parameter); GC::ShareParty>(argc, argv, opt); } diff --git a/Machines/semi-bin-party.cpp b/Machines/semi-bin-party.cpp index fbd0a6345..6c99ebf9c 100644 --- a/Machines/semi-bin-party.cpp +++ b/Machines/semi-bin-party.cpp @@ -14,6 +14,7 @@ #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" #include "GC/Processor.hpp" +#include "GC/SemiSecret.hpp" #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/SemiInput.hpp" diff --git a/Machines/soho-party.cpp b/Machines/soho-party.cpp index 7ecc450da..ced64919f 100644 --- a/Machines/soho-party.cpp +++ b/Machines/soho-party.cpp @@ -25,6 +25,7 @@ #include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" +#include "GC/SemiSecret.hpp" #include "Math/gfp.hpp" int main(int argc, const char** argv) diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index 8aba31737..9e8267376 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -22,12 +22,12 @@ int main(int argc, const char** argv) 1, // Number of args expected. 0, // Delimiter if expecting multiple args. "SPDZ2k security parameter (default: 64)", // Help description. - "-S", // Flag token. - "--security" // Flag token. + "-SP", // Flag token. + "--spdz2k-security" // Flag token. ); opt.parse(argc, argv); int s; - opt.get("-S")->getInt(s); + opt.get("-SP")->getInt(s); opt.resetArgs(); RingOptions ring_options(opt, argc, argv); int k = ring_options.R; @@ -62,6 +62,8 @@ int main(int argc, const char** argv) cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line " << (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+" << s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl; + cerr << "Alternatively, compile with -DRING_SIZE=" << k + << " and -DSPDZ2K_DEFAULT_SECURITY=" << s << endl; } exit(1); } diff --git a/Machines/temi-party.cpp b/Machines/temi-party.cpp index 12e99dc27..f8abd35d8 100644 --- a/Machines/temi-party.cpp +++ b/Machines/temi-party.cpp @@ -26,6 +26,7 @@ #include "Protocols/Hemi.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" +#include "GC/SemiSecret.hpp" #include "Math/gfp.hpp" int main(int argc, const char** argv) diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 82122bd11..67234a8a5 100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -28,6 +28,8 @@ int main(int argc, const char** argv) { - gf2n_short::init_field(40); + ez::ezOptionParser opt; + OnlineOptions opts(opt, argc, argv); + gf2n_short::init_minimum(opts.security_parameter); GC::simple_binary_main>(argc, argv, 1000); } diff --git a/Machines/tiny-party.cpp b/Machines/tiny-party.cpp index 7e72361eb..f83f839f5 100644 --- a/Machines/tiny-party.cpp +++ b/Machines/tiny-party.cpp @@ -29,5 +29,5 @@ int main(int argc, const char** argv) { - GC::simple_binary_main>(argc, argv, 1000); + GC::simple_binary_main>(argc, argv, 1000); } diff --git a/Makefile b/Makefile index 4f558e1d6..3c2be0090 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR) -GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +GC_SEMI = GC/SemiPrep.o GC/square64.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) OT_EXE = ot.x ot-offline.x @@ -57,7 +57,7 @@ vm: arithmetic binary doc: cd doc; $(MAKE) html -arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy +arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy dealer-ring-party.x binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr all: overdrive she-offline @@ -162,7 +162,7 @@ bmr-%.x: $(BMR) $(VM) Machines/bmr-%.cpp $(LIBSIMPLEOT) bmr-clean: -rm BMR/*.o BMR/*/*.o GC/*.o -bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) +bankers-bonus-client.x: ExternalIO/bankers-bonus-client.o $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) simple-offline.x: $(FHEOFFLINE) @@ -203,13 +203,13 @@ replicated-field-party.x: GC/square64.o brain-party.x: GC/square64.o malicious-rep-bin-party.x: GC/square64.o ps-rep-bin-party.x: GC/PostSacriBin.o -semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +semi-bin-party.x: $(OT) GC/SemiPrep.o GC/square64.o tiny-party.x: $(OT) tinier-party.x: $(OT) spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) -semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o -semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +semi-party.x: $(OT) GC/SemiPrep.o GC/square64.o +semi2k-party.x: $(OT) GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) @@ -234,15 +234,16 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o rep4-ring-party.x: GC/Rep4Secret.o no-party.x: Protocols/ShareInterface.o -semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o +semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) emulate.x: GC/FakeSecret.o -semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT) +semi-bmr-party.x: GC/SemiPrep.o $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o -mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o +l2h-example.x: $(VM) $(OT) Machines/Tinier.o mascot-offline.x: $(VM) $(TINIER) cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) @@ -253,6 +254,7 @@ static/semi-bmr-party.x: $(BMR) static/real-bmr-party.x: $(BMR) static/bmr-program-party.x: $(BMR) static/no-party.x: Protocols/ShareInterface.o +Test/failure.x: Protocols/MalRepRingOptions.o ifeq ($(AVX_OT), 1) $(LIBSIMPLEOT): SimpleOT/Makefile @@ -270,7 +272,7 @@ Programs/Circuits: .PHONY: mpir-setup mpir-global mpir mpir-setup: - git submodule update --init mpir + git submodule update --init mpir || git clone https://github.com/wbhart/mpir cd mpir; \ autoreconf -i; \ autoreconf -i @@ -306,7 +308,7 @@ linux-machine-setup: endif simde/simde: - git submodule update --init simde + git submodule update --init simde || git clone https://github.com/simd-everywhere/simde clean: -rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so diff --git a/Math/Integer.h b/Math/Integer.h index 8104724c0..1fbb257fc 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -37,8 +37,6 @@ class IntBase : public ValueInterface static void specification(octetStream& os); - static void init_default(int lgp) { (void)lgp; } - static bool allows(Dtype type) { return type <= DATA_BIT; } IntBase() { a = 0; } diff --git a/Math/Setup.cpp b/Math/Setup.cpp index b4800017c..dc76e47d7 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -27,6 +27,13 @@ void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m) cerr << "Setting up parameters" << endl; #endif + m = default_m(lgp, idx); + generate_prime(p, lgp, m); +} + +int default_m(int& lgp, int& idx) +{ + int m; switch (lgp) { case -1: m=16; @@ -56,15 +63,12 @@ void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m) default: m=1; idx=0; -#ifdef VERBOSE - cerr << "no precomputed parameters, trying anyway" << endl; -#endif break; } #ifdef VERBOSE cerr << "m = " << m << endl; #endif - generate_prime(p, lgp, m); + return m; } bigint generate_prime(int lgp, int m) @@ -95,6 +99,9 @@ void generate_prime(bigint& p, int lgp, int m) return; } + int idx; + m = max(m, default_m(lgp, idx)); + bigint u; int ex; ex = lgp - numBits(m); diff --git a/Math/Setup.h b/Math/Setup.h index f8405ba3b..8c599198e 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -35,6 +35,7 @@ bigint SPDZ_Data_Setup_Primes(int lgp); void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m); void generate_prime(bigint& p, int lgp, int m); bigint generate_prime(int lgp, int m); +int default_m(int& lgp, int& idx); string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, const string& type_short); diff --git a/Math/bigint.h b/Math/bigint.h index 5cd319817..cb79f2424 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -12,6 +12,7 @@ using namespace std; #include "Tools/random.h" #include "Tools/octetStream.h" #include "Tools/avx_memcpy.h" +#include "Protocols/config.h" enum ReportType { @@ -270,7 +271,8 @@ inline int probPrime(const bigint& x) { gmp_randstate_t rand_state; gmp_randinit_default(rand_state); - int ans=mpz_probable_prime_p(x.get_mpz_t(),rand_state,40,0); + int ans = mpz_probable_prime_p(x.get_mpz_t(), rand_state, + max(40, DEFAULT_SECURITY), 0); gmp_randclear(rand_state); return ans; } diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index f9491fb7b..44e424794 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -18,7 +18,7 @@ bool gf2n_::useC; word gf2n_short_table[256][256]; -#define num_2_fields 7 +#define num_2_fields 17 /* Require * 2*(n-1)-64+t1<64 @@ -26,11 +26,21 @@ word gf2n_short_table[256][256]; int fields_2[num_2_fields][4] = { { 4, 1, 0, 0 }, + { 5, 2, 0, 0 }, + { 6, 1, 0, 0 }, + { 7, 1, 0, 0 }, { 8, 4, 3, 1 }, + { 9, 1, 0, 0 }, + { 10, 3, 0, 0}, + { 11, 2, 0, 0}, + { 12, 3, 0, 0}, + { 14, 5, 0, 0}, + { 15, 1, 0, 0}, { 16, 5, 3, 1 }, { 28, 1, 0, 0 }, { 40, 20, 15, 10 }, { 63, 1, 0, 0 }, + { 64, 4, 3, 1}, { 128, 7, 2, 1 }, }; @@ -55,6 +65,21 @@ void gf2n_::init_tables() } } +template +void gf2n_::init_minimum(int lower) +{ + if (lower <= n) + return; + + for (int i = 0; i < num_2_fields; i++) + { + int n = fields_2[i][0]; + if (lower <= n and n <= MAX_N_BITS) + return init_field(n); + } + throw runtime_error("no suitable field for minimum degree " + to_string(lower)); +} + void gf2n_short::init_field(int nn) { super::init_field(nn == 0 ? DEFAULT_LENGTH : nn); @@ -88,7 +113,7 @@ void gf2n_::init_field(int nn) if (j==-1) { - throw runtime_error("field size not supported"); + throw gf2n_not_supported(nn); } n=nn; @@ -332,7 +357,11 @@ gf2n_ gf2n_::invert() const if (n < 64) return U(invert(a)); else - return invert>(a).get_lower(); + { + gf2n_ res; + res.a = invert(a).get_lower(); + return res; + } } template<> diff --git a/Math/gf2n.h b/Math/gf2n.h index add8627cf..485d84308 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -65,6 +65,7 @@ class gf2n_ : public ValueInterface static void init_field(int nn = 0); static void init_default(int, bool = false) { init_field(); } + static void init_minimum(int lower); static void reset() { n = 0; } static int degree() { return n; } @@ -213,7 +214,7 @@ class gf2n_short : public gf2n_ static const int DEFAULT_LENGTH = 40; static int length() { return n == 0 ? DEFAULT_LENGTH : n; } - static int default_degree() { return 40; } + static int default_degree() { return DEFAULT_LENGTH; } static void init_field(int nn = 0); diff --git a/Networking/AllButLastPlayer.h b/Networking/AllButLastPlayer.h new file mode 100644 index 000000000..22482c481 --- /dev/null +++ b/Networking/AllButLastPlayer.h @@ -0,0 +1,67 @@ +/* + * AllButZeroPlayer.h + * + */ + +#ifndef NETWORKING_ALLBUTLASTPLAYER_H_ +#define NETWORKING_ALLBUTLASTPLAYER_H_ + +#include "Player.h" + +class AllButLastPlayer : public Player +{ + const Player& P; + Names* N; + +public: + AllButLastPlayer(const Player& P) : + Player(*(N = new Names(P.my_num(), P.num_players() - 1))), P(P) + { + } + + ~AllButLastPlayer() + { + delete N; + } + + void send_to_no_stats(int player, const octetStream& o) const + { + P.send_to(player, o); + } + + void receive_player_no_stats(int i, octetStream& o) const + { + P.receive_player(i, o); + } + + void send_receive_all_no_stats(const vector>& channels, + const vector& to_send, + vector& to_receive) const + { + auto my_channels = channels; + my_channels.resize(P.num_players()); + for (auto& x : my_channels) + x.resize(P.num_players()); + auto my_to_send = to_send; + if (P.my_num() != P.num_players() - 1) + P.send_receive_all(my_channels, my_to_send, to_receive); + to_receive.resize(P.num_players() - 1); + } + + void Broadcast_Receive_no_stats(vector& os) const + { + vector to_send(P.num_players(), os[P.my_num()]); + vector> channels(P.num_players(), + vector(P.num_players(), true)); + for (auto& x: channels) + x.back() = false; + channels.back() = vector(P.num_players(), false); + vector to_receive; + P.send_receive_all(channels, to_send, to_receive); + for (int i = 0; i < P.num_players() - 1; i++) + if (i != P.my_num()) + os[i] = to_receive[i]; + } +}; + +#endif diff --git a/Networking/Player.cpp b/Networking/Player.cpp index b4bab177f..a7935f305 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -146,10 +146,18 @@ void Names::setup_names(const char *servername, int my_port) #endif // Now get the set of names - octetStream os; - os.Receive(socket_num); - os.get(names); - os.get(ports); + try + { + octetStream os; + os.Receive(socket_num); + os.get(names); + os.get(ports); + } + catch (exception& e) + { + throw runtime_error(string("error in network setup: ") + e.what()); + } + if (names.size() != ports.size()) throw runtime_error("invalid network setup"); nplayers = names.size(); @@ -186,6 +194,11 @@ Names::Names(const Names& other) server = 0; } +Names::Names(int my_num, int num_players) : + nplayers(num_players), portnum_base(-1), player_no(my_num), server(0) +{ +} + Names::~Names() { if (server != 0) @@ -817,6 +830,17 @@ void NamedCommStats::print(bool newline) cerr << endl; } +void NamedCommStats::reset() +{ + clear(); + sent = 0; +} + +void PlayerBase::reset_stats() +{ + comm_stats.reset(); +} + NamedCommStats Player::total_comm() const { auto res = comm_stats; diff --git a/Networking/Player.h b/Networking/Player.h index ff4bdcd1d..a547d4795 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -116,7 +116,7 @@ class Names Names(ez::ezOptionParser& opt, int argc, const char** argv, int default_nplayers = 2); - Names() : nplayers(1), portnum_base(-1), player_no(0), server(0) { ; } + Names(int my_num = 0, int num_players = 1); Names(const Names& other); ~Names(); @@ -159,6 +159,7 @@ class NamedCommStats : public map NamedCommStats operator-(const NamedCommStats& other) const; size_t total_data(); void print(bool newline = false); + void reset(); #ifdef VERBOSE_COMM CommStats& operator[](const string& name) { @@ -190,10 +191,19 @@ class PlayerBase virtual int my_num() const = 0; virtual int num_players() const = 0; - virtual void pass_around(octetStream& o, int offset = 1) const = 0; - virtual void Broadcast_Receive(vector& o) const = 0; + virtual void receive_player(int, octetStream&) const + { throw not_implemented(); } + virtual void pass_around(octetStream&, int = 1) const + { throw not_implemented(); } + virtual void Broadcast_Receive(vector&) const + { throw not_implemented(); } virtual void unchecked_broadcast(vector& o) const { Broadcast_Receive(o); } + virtual void send_receive_all(const vector&, + vector&) const + { throw not_implemented(); } + + void reset_stats(); }; /** @@ -230,8 +240,8 @@ class Player : public PlayerBase virtual bool is_encrypted() { return false; } - virtual void send_long(int i, long a) const = 0; - virtual long receive_long(int i) const = 0; + virtual void send_long(int, long) const { throw not_implemented(); } + virtual long receive_long(int) const { throw not_implemented(); } // The following functions generally update the statistics // and then call the *_no_stats equivalent specified by a subclass. @@ -283,7 +293,8 @@ class Player : public PlayerBase * reusing the buffer if possible. */ void exchange(int other, const octetStream& to_send, octetStream& ot_receive) const; - virtual void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const = 0; + virtual void exchange_no_stats(int, const octetStream&, octetStream&) const + { throw runtime_error("implement exchange"); } /** * Exchange information with one other party, reusing the buffer. */ @@ -304,8 +315,8 @@ class Player : public PlayerBase * The default is to send to the next party while receiving from the previous. */ void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const; - virtual void pass_around_no_stats(const octetStream& to_send, - octetStream& to_receive, int offset) const = 0; + virtual void pass_around_no_stats(const octetStream&, octetStream&, + int) const { throw runtime_error("implement passing around"); } /** * Broadcast and receive data to/from all players. @@ -317,7 +328,8 @@ class Player : public PlayerBase * Assumes o[player_no] contains the data to be broadcast by me. */ virtual void Broadcast_Receive(vector& o) const; - virtual void Broadcast_Receive_no_stats(vector& o) const = 0; + virtual void Broadcast_Receive_no_stats(vector&) const + { throw runtime_error("implement broadcast"); } /** * Run protocol to verify broadcast is correct diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 732850bc4..5fdbb3d6c 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -15,6 +15,7 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiInput.hpp" #include "Protocols/SemiMC.hpp" +#include "Protocols/mac_key.hpp" #include #include @@ -274,9 +275,9 @@ void NPartyTripleGenerator::generateInputs(int player) inputs.resize(nTriplesPerLoop); typename W::input_check_type::MAC_Check MC(mac_key); - MC.POpen(check_sum, globalPlayer); // use zero element because all is perfectly randomized MC.set_random_element({}); + MC.POpen(check_sum, globalPlayer); MC.Check(globalPlayer); } @@ -673,7 +674,7 @@ void MascotTripleGenerator::sacrifice(typename T::MAC_Check& MC, PRNG& G) auto& outputFile = this->outputFile; auto& uncheckedTriples = this->uncheckedTriples; - assert(T::clear::length() >= 40); + check_field_size(); vector maskedAs(nTriplesPerLoop); vector > maskedTriples(nTriplesPerLoop); @@ -744,6 +745,8 @@ void Spdz2kTripleGenerator::sacrificeZ2k(U& MC, PRNG& G) // and first part of [sigma], i.e., t * [c] - [chat] maskedTriples[j].template prepare_sacrifice(uncheckedTriples[j], G); maskedAs[j] = maskedTriples[j].a[0]; + // enough randomness in values + MC.set_random_element({}); } vector openedAs(nTriplesPerLoop); @@ -754,6 +757,8 @@ void Spdz2kTripleGenerator::sacrificeZ2k(U& MC, PRNG& G) for (int j = 0; j < nTriplesPerLoop; j++) { // compute t * [c] - [chat] - [b] * p sigmas.push_back(maskedTriples[j].computeCheckShare(V(openedAs[j]))); + // enough randomness in values + MC.set_random_element({}); } vector open_sigmas; diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 019fc6f28..33d5f441b 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -59,6 +59,8 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) cerr << "Number of program sequences I need to load = " << nprogs << endl; #endif + bc_filenames.clear(); + // Load in the programs string threadname; for (int i=0; i void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, false_type) { -#ifndef INSECURE - throw runtime_error("no secure implementation of reading edaBits from files"); -#endif + insecure("reading edaBits from files"); if (edabit_buffers.find(n_bits) == edabit_buffers.end()) { string filename = PrepBase::get_edabit_filename(prep_data_dir, diff --git a/Processor/Input.h b/Processor/Input.h index 728c81f6a..0a84a55f2 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -26,7 +26,7 @@ class InputBase typedef typename T::clear clear; protected: - Player* P; + PlayerBase* P; int my_num; Buffer buffer; @@ -63,7 +63,7 @@ class InputBase /// Initialize input round for ``player`` virtual void reset(int player) = 0; /// Initialize input round for all players - void reset_all(Player& P); + void reset_all(PlayerBase& P); /// Schedule input from me virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 246c9eb1d..09c6e056a 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -81,7 +81,7 @@ void InputBase::reset(int player) } template -void InputBase::reset_all(Player& P) +void InputBase::reset_all(PlayerBase& P) { this->P = &P; my_num = P.my_num(); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index a7e1e3185..5279b2584 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -328,14 +328,14 @@ class BaseInstruction int opcode; // The code int size; // Vector size int r[4]; // Fixed parameter registers - unsigned int n; // Possible immediate value + size_t n; // Possible immediate value vector start; // Values for a start/stop open public: virtual ~BaseInstruction() {}; int get_r(int i) const { return r[i]; } - unsigned int get_n() const { return n; } + size_t get_n() const { return n; } const vector& get_start() const { return start; } int get_opcode() const { return opcode; } int get_size() const { return size; } @@ -350,7 +350,7 @@ class BaseInstruction bool is_direct_memory_access() const; // Returns the memory size used if applicable and known - unsigned get_mem(RegType reg_type) const; + size_t get_mem(RegType reg_type) const; // Returns the maximal register used unsigned get_max_reg(int reg_type) const; diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 1bc46f94f..2a5dce70c 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -218,24 +218,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // instructions with 1 register + 1 integer operand case LDI: case LDSI: - case LDMC: - case LDMS: - case STMC: - case STMS: - case LDMSB: - case STMSB: - case LDMCB: - case STMCB: - case LDMINT: - case STMINT: case JMPNZ: case JMPEQZ: case GLDI: case GLDSI: - case GLDMC: - case GLDMS: - case GSTMC: - case GSTMS: case PRINTREG: case PRINTREGB: case GPRINTREG: @@ -247,6 +233,24 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) r[0]=get_int(s); n = get_int(s); break; + // instructions with 1 register + 1 long operand + case LDMC: + case LDMS: + case STMC: + case STMS: + case LDMSB: + case STMSB: + case LDMCB: + case STMCB: + case LDMINT: + case STMINT: + case GLDMC: + case GLDMS: + case GSTMC: + case GSTMS: + r[0] = get_int(s); + n = get_long(s); + break; // instructions with 1 integer operand case PRINTSTR: case PRINTCHR: @@ -783,7 +787,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const } inline -unsigned BaseInstruction::get_mem(RegType reg_type) const +size_t BaseInstruction::get_mem(RegType reg_type) const { if (get_reg_type() == reg_type and is_direct_memory_access()) return n + size; @@ -843,7 +847,7 @@ inline void Instruction::execute(Processor& Proc) const } int r[3] = {this->r[0], this->r[1], this->r[2]}; - int n = this->n; + int64_t n = this->n; for (int i = 0; i < size; i++) { switch (opcode) { @@ -1065,7 +1069,7 @@ inline void Instruction::execute(Processor& Proc) const case PRINTREG: { Proc.out << "Reg[" << r[0] << "] = " << Proc.read_Cp(r[0]) - << " # " << string((char*)&n,sizeof(n)) << endl; + << " # " << string((char*)&n, 4) << endl; } break; case PRINTREGPLAIN: @@ -1085,7 +1089,7 @@ inline void Instruction::execute(Processor& Proc) const case CONDPRINTSTR: if (not Proc.read_Cp(r[0]).is_zero()) { - string str = {(char*)&n, sizeof(n)}; + string str = {(char*)&n, 4}; size_t n = str.find('\0'); if (n < 4) str.erase(n); @@ -1313,7 +1317,7 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c out << "["; for (int i = 0; i < size; i++) { - if (p == 0) + if (p == 0 or (*p == 0 and s == 0)) out << v[i]; else if (s == 0) out << bigint::get_float(v[i], p[i], {}, {}); diff --git a/Processor/Machine.h b/Processor/Machine.h index 331a9a22c..8b3d018cc 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -46,6 +46,8 @@ class Machine : public BaseMachine void load_program(const string& threadname, const string& filename); + void prepare(const string& progname_str); + void suggest_optimizations(); public: @@ -71,10 +73,10 @@ class Machine : public BaseMachine ExecutionStats stats; - Machine(int my_number, Names& playerNames, const string& progname, - const string& memtype, int lg2, bool direct, int opening_sum, - bool receive_threads, int max_broadcast, bool use_encryption, bool live_prep, - OnlineOptions opts); + static void init_binary_domains(int security_parameter, int lg2); + + Machine(Names& playerNames, bool use_encryption = true, + const OnlineOptions opts = sint(), int lg2 = 0); ~Machine(); const Names& get_N() { return N; } @@ -92,7 +94,11 @@ class Machine : public BaseMachine DataPositions run_tape(int thread_number, int tape_number, int arg, const DataPositions& pos); DataPositions join_tape(int thread_number); - void run(); + + void run(const string& progname); + + void run_step(const string& progname); + pair stop_threads(); string memory_filename(); @@ -102,6 +108,9 @@ class Machine : public BaseMachine void reqbl(int n); typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } + typename sint::mac_key_type get_sint_mac_key() { return alphapi; } + + Player& get_player() { return *P; } }; #endif /* MACHINE_H_ */ diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index e720b2a99..e0299c2f3 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -24,28 +24,52 @@ using namespace std; template -Machine::Machine(int my_number, Names& playerNames, - const string& progname_str, const string& memtype, - int lg2, bool direct, - int opening_sum, bool receive_threads, int max_broadcast, - bool use_encryption, bool live_prep, OnlineOptions opts) - : my_number(my_number), N(playerNames), - direct(direct), opening_sum(opening_sum), - receive_threads(receive_threads), max_broadcast(max_broadcast), - use_encryption(use_encryption), live_prep(live_prep), opts(opts) +void Machine::init_binary_domains(int security_parameter, int lg2) { + sgf2n::clear::init_field(lg2); + + if (not is_same()) + { + if (sgf2n::clear::degree() < security_parameter) + { + cerr << "Security parameter needs to be at most n in GF(2^n)." + << endl; + cerr << "Increase the latter (-lg2) or decrease the former (-S)." + << endl; + exit(1); + } + } + + if (not is_same()) + { + sint::bit_type::mac_key_type::init_minimum(security_parameter); + } + else + { + // Initialize field for CCD + sint::bit_type::part_type::open_type::init_field(); + } +} + +template +Machine::Machine(Names& playerNames, bool use_encryption, + const OnlineOptions opts, int lg2) + : my_number(playerNames.my_num()), N(playerNames), + direct(opts.direct), opening_sum(opts.opening_sum), + receive_threads(opts.receive_threads), max_broadcast(opts.max_broadcast), + use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts) +{ + OnlineOptions::singleton = opts; + if (opening_sum < 2) this->opening_sum = N.num_players(); if (max_broadcast < 2) this->max_broadcast = N.num_players(); // Set up the fields - sgf2n::clear::init_field(lg2); sint::clear::read_or_generate_setup(prep_dir_prefix(), opts); - sint::bit_type::mac_key_type::init_field(); - // Initialize gf2n_short for CCD - sint::bit_type::part_type::open_type::init_field(); + init_binary_domains(opts.security_parameter, lg2); // make directory for outputs if necessary mkdir_p(PREP_DIR); @@ -75,6 +99,7 @@ Machine::Machine(int my_number, Names& playerNames, sint::clear::next::template init(false); // Initialize the global memory + auto memtype = opts.memtype; if (memtype.compare("old")==0) { ifstream inpf; @@ -92,9 +117,18 @@ Machine::Machine(int my_number, Names& playerNames, { cerr << "Invalid memory argument" << endl; exit(1); } +} +template +void Machine::prepare(const string& progname_str) +{ + int old_n_threads = nthreads; + progs.clear(); load_schedule(progname_str); + // keep preprocessing + nthreads = max(old_n_threads, nthreads); + // initialize persistence if necessary for (auto& prog : progs) { @@ -122,7 +156,7 @@ Machine::Machine(int my_number, Names& playerNames, if (live_prep and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot)) { - for (int i = 0; i < nthreads; i++) + for (int i = old_n_threads; i < nthreads; i++) ot_setups.push_back({ *P, true }); } @@ -132,7 +166,7 @@ Machine::Machine(int my_number, Names& playerNames, queues.resize(nthreads); join_timer.resize(nthreads); - for (int i=0; i::Machine(int my_number, Names& playerNames, } // synchronize with clients before starting timer - for (int i=0; iresult(); } @@ -155,6 +189,9 @@ Machine::Machine(int my_number, Names& playerNames, template Machine::~Machine() { + sint::LivePrep::teardown(); + sgf2n::LivePrep::teardown(); + delete P; for (auto& queue : queues) delete queue; @@ -308,14 +345,12 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, //printf("Running line %d\n",exec); if (progs[tape_number].usage_unknown()) { -#ifndef INSECURE if (not opts.live_prep and thread_number != 0) { - cerr << "Internally called tape " << tape_number << - " has unknown offline data usage" << endl; - throw invalid_program(); + insecure( + "Internally called tape " + to_string(tape_number) + + " has unknown offline data usage"); } -#endif return DataPositions(N.num_players()); } else @@ -336,23 +371,20 @@ DataPositions Machine::join_tape(int i) } template -void Machine::run() +void Machine::run_step(const string& progname) { - Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); - proc_timer.start(); - timer[0].start({}); - - // run main tape + prepare(progname); run_tape(0, 0, 0, N.num_players()); join_tape(0); +} - print_compiler(); - - finish_timer.start(); +template +pair Machine::stop_threads() +{ // Tell all C-threads to stop for (int i=0; ischedule(-1); } @@ -369,6 +401,40 @@ void Machine::run() pos.increase(queues[i]->result().pos); pthread_join(threads[i],NULL); } + + auto comm_stats = total_comm(); + + for (auto& queue : queues) + delete queue; + + queues.clear(); + + nthreads = 0; + + return {pos, comm_stats}; +} + +template +void Machine::run(const string& progname) +{ + prepare(progname); + + Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); + proc_timer.start(); + timer[0].start({}); + + // run main tape + run_tape(0, 0, 0, N.num_players()); + join_tape(0); + + print_compiler(); + + finish_timer.start(); + + // actual usage + auto res = stop_threads(); + DataPositions& pos = res.first; + finish_timer.stop(); #ifdef VERBOSE @@ -387,7 +453,7 @@ void Machine::run() cerr << "Finish timer: " << finish_timer.elapsed() << endl; #endif - NamedCommStats comm_stats = total_comm(); + NamedCommStats& comm_stats = res.second; if (opts.verbose) { @@ -475,17 +541,12 @@ void Machine::run() stats.print(); } -#ifndef INSECURE if (not opts.file_prep_per_thread) { Data_Files df(*this); df.seekg(pos); df.prune(); } -#endif - - sint::LivePrep::teardown(); - sgf2n::LivePrep::teardown(); suggest_optimizations(); diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index 6e0bb525d..b04ea6a15 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -37,13 +37,16 @@ template int OfflineMachine::run() { T::clear::init_default(this->online_opts.prime_length()); - U::clear::init_field(U::clear::default_degree()); - T::bit_type::mac_key_type::init_field(); + Machine::init_binary_domains(this->online_opts.security_parameter, + this->lg2); auto binary_mac_key = read_generate_write_mac_key< typename T::bit_type::part_type>(P); typename T::bit_type::LivePrep bit_prep(usage); GC::ShareThread thread(bit_prep, P, binary_mac_key); + // setup before generation to fix prime + T::LivePrep::basic_setup(P); + generate(); generate(); generate(); diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index e98f1a3a1..b0c5e5795 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -100,6 +100,9 @@ void thread_info::Sub_Main_Func() processor = new Processor(tinfo->thread_num,P,*MC2,*MCp,machine,progs.at(thread_num > 0)); auto& Proc = *processor; + // don't count communication for initialization + P.reset_stats(); + bool flag=true; int program=-3; // int exec=0; @@ -287,10 +290,8 @@ void thread_info::Sub_Main_Func() // final check Proc.check(); -#ifndef INSECURE if (machine.opts.file_prep_per_thread) Proc.DataF.prune(); -#endif wait_timer.start(); queues->next(); diff --git a/Processor/OnlineMachine.h b/Processor/OnlineMachine.h index 9804828eb..68e38e379 100644 --- a/Processor/OnlineMachine.h +++ b/Processor/OnlineMachine.h @@ -17,11 +17,11 @@ class OnlineMachine const char** argv; OnlineOptions& online_opts; - int lg2, opening_sum, max_broadcast; + int lg2; Names playerNames; - bool use_encryption, receive_threads; + bool use_encryption; ez::ezOptionParser& opt; diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 4e944d624..d4c66e9aa 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -18,7 +18,7 @@ template int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_prep_default = true) { OnlineOptions& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, 1000, live_prep_default, T::clear::invertible}; + online_opts = {opt, argc, argv, T(), live_prep_default}; DishonestMajorityMachine machine(argc, argv, opt, online_opts, typename U::clear()); return machine.run(); @@ -28,8 +28,7 @@ template OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers, V) : argc(argc), argv(argv), online_opts(online_opts), lg2(0), - opening_sum(0), max_broadcast(0), - use_encryption(false), receive_threads(false), + use_encryption(false), opt(opt), nplayers(nplayers) { opt.add( @@ -125,33 +124,6 @@ DishonestMajorityMachine::DishonestMajorityMachine(int argc, const char** argv, opt.example = string() + argv[0] + " -p 0 -N 2 sample-prog\n" + argv[0] + " -h localhost -p 1 -N 2 sample-prog\n"; - opt.add( - "0", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Sum at most n shares at once when using indirect communication", // Help description. - "-s", // Flag token. - "--opening-sum" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Use player-specific threads for communication", // Help description. - "-t", // Flag token. - "--threads" // Flag token. - ); - opt.add( - "0", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Maximum number of parties to send to at once", // Help description. - "-mb", // Flag token. - "--max-broadcast" // Flag token. - ); opt.add( "", // Default. 0, // Required? @@ -163,11 +135,7 @@ DishonestMajorityMachine::DishonestMajorityMachine(int argc, const char** argv, ); online_opts.finalize(opt, argc, argv); - opt.get("--opening-sum")->getInt(opening_sum); - opt.get("--max-broadcast")->getInt(max_broadcast); - use_encryption = opt.isSet("--encrypted"); - receive_threads = opt.isSet("--threads"); start_networking(); } @@ -230,12 +198,8 @@ int OnlineMachine::run() try #endif { - Machine(online_opts.playerno, playerNames, online_opts.progname, - online_opts.memtype, lg2, - online_opts.direct, opening_sum, - receive_threads, max_broadcast, - use_encryption, online_opts.live_prep, - online_opts).run(); + Machine(playerNames, use_encryption, online_opts, lg2).run( + online_opts.progname); if (online_opts.verbose) { diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 2a5e090bd..d404f642d 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -8,6 +8,7 @@ #include "Math/gfp.h" #include "Math/gfpvar.h" #include "Protocols/HemiOptions.h" +#include "Protocols/config.h" #include "Math/gfp.hpp" @@ -26,10 +27,14 @@ OnlineOptions::OnlineOptions() : playerno(-1) bits_from_squares = false; direct = false; bucket_size = 4; + security_parameter = DEFAULT_SECURITY; cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; file_prep_per_thread = false; - trunc_error = 40; + trunc_error = DEFAULT_SECURITY; + opening_sum = 0; + max_broadcast = 0; + receive_threads = false; #ifdef VERBOSE verbose = true; #else @@ -38,7 +43,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, - const char** argv, false_type) : + const char** argv, bool security) : OnlineOptions() { opt.syntax = std::string(argv[0]) + " [OPTIONS] [] "; @@ -104,6 +109,18 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--bucket-size" // Flag token. ); + if (security) + opt.add( + to_string(security_parameter).c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + ("Security parameter (default: " + to_string(security_parameter) + + ")").c_str(), // Help description. + "-S", // Flag token. + "--security" // Flag token. + ); + opt.parse(argc, argv); interactive = opt.isSet("-I"); @@ -117,13 +134,24 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, verbose = opt.isSet("--verbose"); #endif + if (security) + { + opt.get("-S")->getInt(security_parameter); + cerr << "Using security parameter " << security_parameter << endl; + if (security_parameter <= 0) + { + cerr << "Invalid security parameter: " << security_parameter << endl; + exit(1); + } + } + opt.resetArgs(); } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size, bool default_live_prep, - bool variable_prime_length) : - OnlineOptions(opt, argc, argv, false_type()) + bool variable_prime_length, bool security) : + OnlineOptions(opt, argc, argv, security) { if (default_batch_size <= 0) default_batch_size = batch_size; @@ -263,6 +291,9 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, vector badOptions; unsigned int i; + opt.footer += "\nSee also https://mp-spdz.readthedocs.io/en/latest/networking.html " + "for documentation on the networking setup.\n"; + if (allArgs.size() != 3u - opt.isSet("-p")) { cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl; @@ -329,6 +360,16 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, } set_trunc_error(opt); + + auto o = opt.get("--opening-sum"); + if (o) + o->getInt(opening_sum); + + o = opt.get("--max-broadcast"); + if (o) + o->getInt(max_broadcast); + + receive_threads = opt.isSet("--threads"); } void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt) diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 4b2fe4f8c..61c1352bc 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -26,21 +26,26 @@ class OnlineOptions bool bits_from_squares; bool direct; int bucket_size; + int security_parameter; std::string cmd_private_input_file; std::string cmd_private_output_file; bool verbose; bool file_prep_per_thread; int trunc_error; + int opening_sum, max_broadcast; + bool receive_threads; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, - false_type); + bool security); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, - bool variable_prime_length = false); + bool variable_prime_length = false, bool security = true); template OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T, bool default_live_prep = true); + template + OnlineOptions(T); ~OnlineOptions() {} void finalize(ez::ezOptionParser& opt, int argc, const char** argv); diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp index 8961853e5..d8b71cea6 100644 --- a/Processor/OnlineOptions.hpp +++ b/Processor/OnlineOptions.hpp @@ -20,11 +20,49 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Probabilistic truncation error " - "(2^-x, default: 40)", // Help description. + ("Probabilistic truncation error (2^-x, default: " + + to_string(trunc_error) + ")").c_str(), // Help description. "-E", // Flag token. "--trunc-error" // Flag token. ); + + if (T::dishonest_majority) + { + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Sum at most n shares at once when using indirect communication", // Help description. + "-s", // Flag token. + "--opening-sum" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use player-specific threads for communication", // Help description. + "-t", // Flag token. + "--threads" // Flag token. + ); + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Maximum number of parties to send to at once", // Help description. + "-mb", // Flag token. + "--max-broadcast" // Flag token. + ); + } +} + +template +OnlineOptions::OnlineOptions(T) : OnlineOptions() +{ + if (T::dishonest_majority) + batch_size = 1000; } #endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */ diff --git a/Processor/Program.h b/Processor/Program.h index 87a263f08..8fb3df141 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -21,7 +21,7 @@ class Program unsigned max_reg[MAX_REG_TYPE]; // Memory size used directly - unsigned max_mem[MAX_REG_TYPE]; + size_t max_mem[MAX_REG_TYPE]; // True if program contains variable-sized loop bool unknown_usage; @@ -48,7 +48,7 @@ class Program unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } - unsigned direct_mem(RegType reg_type) const + size_t direct_mem(RegType reg_type) const { return max_mem[reg_type]; } friend ostream& operator<<(ostream& s,const Program& P); diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index f2bfc6c1b..626942212 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -65,7 +65,7 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri int argc, const char** argv, ez::ezOptionParser& opt) { OnlineOptions online_opts(opt, argc, argv); - RingOptions opts(opt, argc, argv, true); + RingOptions opts(opt, argc, argv); HonestMajorityMachine machine(argc, argv, opt, online_opts); int R = opts.ring_size_from_opts_or_schedule(online_opts.progname); switch (R) @@ -76,15 +76,19 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri break; #define X(K) \ case K: \ - switch (opts.S) \ + { \ + int S = online_opts.security_parameter; \ + switch (S) \ { \ - Y(K, 40) \ + Y(K, DEFAULT_SECURITY) \ default: \ - cerr << "not compiled for security parameter " << to_string(opts.S) << endl; \ - cerr << "add 'Y(K, " << opts.S << ")' to " __FILE__ ", line 76" << endl; \ + cerr << "not compiled for security parameter " << to_string(S) << endl; \ + cerr << "add 'Y(K, " << S << ")' to " __FILE__ ", line 76" << endl; \ + cerr << "or compile with -DDEFAULT_SECURITY=" << S << endl; \ exit(1); \ } \ - break; + break; \ + } X(64) #ifdef RING_SIZE X(RING_SIZE) diff --git a/Processor/RingOptions.cpp b/Processor/RingOptions.cpp index d59a709a8..ec9e9f066 100644 --- a/Processor/RingOptions.cpp +++ b/Processor/RingOptions.cpp @@ -9,8 +9,7 @@ #include using namespace std; -RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, - bool security) +RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) { opt.add( "64", // Default. @@ -21,28 +20,12 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, "-R", // Flag token. "--ring" // Flag token. ); - if (security) - opt.add( - "40", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Security parameter (default: 40)", // Help description. - "-S", // Flag token. - "--security" // Flag token. - ); opt.parse(argc, argv); opt.get("-R")->getInt(R); - if (security) - opt.get("-S")->getInt(S); - else - S = -1; R_is_set = opt.isSet("-R"); opt.resetArgs(); if (R_is_set) cerr << "Trying to run " << R << "-bit computation" << endl; - if (security) - cerr << "Using security parameter " << S << endl; } int RingOptions::ring_size_from_opts_or_schedule(string progname) diff --git a/Processor/RingOptions.h b/Processor/RingOptions.h index 899c7021a..8f5361f60 100644 --- a/Processor/RingOptions.h +++ b/Processor/RingOptions.h @@ -16,10 +16,8 @@ class RingOptions public: int R; - int S; - RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, - bool security = false); + RingOptions(ez::ezOptionParser& opt, int argc, const char** argv); int ring_size_from_opts_or_schedule(string progname); }; diff --git a/Processor/instructions.h b/Processor/instructions.h index 5928fdabc..bf443b0f7 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -203,7 +203,7 @@ *dest++ = *op1++ == *op2++) \ X(PRINTINT, Proc.out << Proc.read_Ci(r[0]) << flush,) \ X(PRINTFLOATPREC, Proc.out << setprecision(n),) \ - X(PRINTSTR, Proc.out << string((char*)&n,sizeof(n)) << flush,) \ + X(PRINTSTR, Proc.out << string((char*)&n,4) << flush,) \ X(PRINTCHR, Proc.out << string((char*)&n,1) << flush,) \ X(SHUFFLE, shuffle(Proc),) \ X(BITDECINT, bitdecint(Proc),) \ @@ -270,7 +270,7 @@ *dest++ = *op1++ >> n) \ X(GPRINTREG, auto source = &C2[r[0]], \ Proc.out << "Reg[" << r[0] << "] = " << *source++ \ - << " # " << string((char*)&n,sizeof(n)) << endl) \ + << " # " << string((char*)&n, 4) << endl) \ X(GPRINTREGPLAIN, auto source = &C2[r[0]], \ Proc.out << *source++ << flush) \ X(GBITDEC, gbitdec(C2),) \ diff --git a/Programs/Source/l2h_comparison.mpc b/Programs/Source/l2h_comparison.mpc new file mode 100644 index 000000000..c233caa77 --- /dev/null +++ b/Programs/Source/l2h_comparison.mpc @@ -0,0 +1,3 @@ +res = sint.load_mem(0) < sint.load_mem(1) +res.store_in_mem(3) +print_ln('comparison in VM: %s', res.reveal()) diff --git a/Programs/Source/l2h_multiplication.mpc b/Programs/Source/l2h_multiplication.mpc new file mode 100644 index 000000000..aecbca651 --- /dev/null +++ b/Programs/Source/l2h_multiplication.mpc @@ -0,0 +1 @@ +(sint.load_mem(0) * sint.load_mem(1)).store_in_mem(2) diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 2d28127c7..9b695d0d1 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -23,6 +23,7 @@ class Player; template class Beaver : public ProtocolBase { +protected: vector shares; vector opened; vector> triples; diff --git a/Protocols/ChaiGearPrep.hpp b/Protocols/ChaiGearPrep.hpp index 69b16fcf2..076931787 100644 --- a/Protocols/ChaiGearPrep.hpp +++ b/Protocols/ChaiGearPrep.hpp @@ -43,15 +43,16 @@ void ChaiGearPrep::basic_setup(Player& P) assert(machine == 0); machine = new MultiplicativeMachine; auto& setup = machine->setup.part(); - auto& options = CowGearOptions::singleton; + int lowgear_security = OnlineOptions::singleton.security_parameter; #ifdef VERBOSE + auto& options = CowGearOptions::singleton; cerr << "Covert security parameter for key and MAC generation: " << options.covert_security << endl; cerr << "Triple generation security parameter: " - << options.lowgear_security << endl; + << lowgear_security << endl; #endif - machine->sec = options.lowgear_security; - setup.secure_init(P, *machine, T::clear::length(), options.lowgear_security); + machine->sec = lowgear_security; + setup.secure_init(P, *machine, T::clear::length(), lowgear_security); T::clear::template init(); #ifdef VERBOSE cerr << T::type_string() << " parameter setup took " << timer.elapsed() diff --git a/Protocols/CowGearOptions.cpp b/Protocols/CowGearOptions.cpp index 9212a7e06..e018dd8bd 100644 --- a/Protocols/CowGearOptions.cpp +++ b/Protocols/CowGearOptions.cpp @@ -23,7 +23,6 @@ CowGearOptions::CowGearOptions(bool covert) covert_security = -1; } - lowgear_security = 40; use_top_gear = false; } @@ -49,7 +48,7 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "LowGear security parameter (default: 40)", // Help description. + "DEPRECATED: use -S/--security", // Help description. "-l", // Flag token. "--lowgear-security" // Flag token. ); @@ -76,15 +75,8 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, opt.get("-c")->getInt(covert_security); if (opt.isSet("-l")) { - opt.get("-l")->getInt(lowgear_security); - if (lowgear_security <= 0) - { - throw exception(); - cerr << "Invalid LowGear Security parameter: " << lowgear_security << endl; - exit(1); - } - if (covert_security > (1LL << lowgear_security)) - insecure(", LowGear security less than key generation security"); + cerr << "Deprecated parameter, use -S/--security" << endl; + exit(1); } use_top_gear = not opt.isSet("-J"); if (opt.isSet("-T")) diff --git a/Protocols/CowGearOptions.h b/Protocols/CowGearOptions.h index f79bd5212..af9006dcc 100644 --- a/Protocols/CowGearOptions.h +++ b/Protocols/CowGearOptions.h @@ -16,7 +16,6 @@ class CowGearOptions static CowGearOptions singleton; int covert_security; - int lowgear_security; CowGearOptions(bool covert = true); CowGearOptions(ez::ezOptionParser& opt, int argc, const char** argv, diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index 3b9daae1e..36b36d663 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -38,14 +38,15 @@ void CowGearPrep::basic_setup(Player& P) pairwise_machine = new PairwiseMachine(P); auto& machine = *pairwise_machine; auto& setup = machine.setup(); - auto& options = CowGearOptions::singleton; + int lowgear_security = OnlineOptions::singleton.security_parameter; #ifdef VERBOSE + auto& options = CowGearOptions::singleton; if (T::covert) cerr << "Covert security parameter for key and MAC generation: " << options.covert_security << endl; - cerr << "LowGear security parameter: " << options.lowgear_security << endl; + cerr << "LowGear security parameter: " << lowgear_security << endl; #endif - setup.secure_init(P, machine, T::clear::length(), options.lowgear_security); + setup.secure_init(P, machine, T::clear::length(), lowgear_security); T::clear::template init(); #ifdef VERBOSE cerr << T::type_string() << " parameter setup took " << timer.elapsed() diff --git a/Protocols/DabitSacrifice.h b/Protocols/DabitSacrifice.h index 3b436547d..6da8cc238 100644 --- a/Protocols/DabitSacrifice.h +++ b/Protocols/DabitSacrifice.h @@ -9,10 +9,12 @@ template class DabitSacrifice { - static const int S = 40; + const int S; public: - static int minimum_n_inputs(int n_outputs = 0) + DabitSacrifice(); + + int minimum_n_inputs(int n_outputs = 0) { if (n_outputs < 1) n_outputs = OnlineOptions::singleton.batch_size; diff --git a/Protocols/DabitSacrifice.hpp b/Protocols/DabitSacrifice.hpp index aa2abf615..74d9f0267 100644 --- a/Protocols/DabitSacrifice.hpp +++ b/Protocols/DabitSacrifice.hpp @@ -11,6 +11,12 @@ #include +template +DabitSacrifice::DabitSacrifice() : + S(OnlineOptions::singleton.security_parameter) +{ +} + template dabit& operator+=(dabit& x, const dabit& y) { diff --git a/Protocols/DealerInput.h b/Protocols/DealerInput.h new file mode 100644 index 000000000..7d0699da4 --- /dev/null +++ b/Protocols/DealerInput.h @@ -0,0 +1,38 @@ +/* + * DealerInput.h + * + */ + +#ifndef PROTOCOLS_DEALERINPUT_H_ +#define PROTOCOLS_DEALERINPUT_H_ + +#include "../Networking/AllButLastPlayer.h" +#include "Processor/Input.h" + +template +class DealerInput : public InputBase +{ + Player& P; + octetStreams to_send, to_receive; + SeededPRNG G; + vector> shares; + bool from_dealer; + AllButLastPlayer sub_player; + SemiInput>* internal; + +public: + DealerInput(SubProcessor& proc, typename T::MAC_Check&); + DealerInput(typename T::MAC_Check&, Preprocessing&, Player& P); + DealerInput(Player& P); + ~DealerInput(); + + bool is_dealer(int player = -1); + + void reset(int player); + void add_mine(const typename T::open_type& input, int n_bits = -1); + void add_other(int player, int n_bits = -1); + void exchange(); + T finalize(int player, int n_bits = -1); +}; + +#endif /* PROTOCOLS_DEALERINPUT_H_ */ diff --git a/Protocols/DealerInput.hpp b/Protocols/DealerInput.hpp new file mode 100644 index 000000000..26bfb9a1a --- /dev/null +++ b/Protocols/DealerInput.hpp @@ -0,0 +1,115 @@ +/* + * DealerInput.hpp + * + */ + +#ifndef PROTOCOLS_DEALERINPUT_HPP_ +#define PROTOCOLS_DEALERINPUT_HPP_ + +#include "DealerInput.h" + +template +DealerInput::DealerInput(SubProcessor& proc, typename T::MAC_Check&) : + DealerInput(proc.P) +{ +} + +template +DealerInput::DealerInput(typename T::MAC_Check&, Preprocessing&, + Player& P) : + DealerInput(P) +{ +} + +template +DealerInput::DealerInput(Player& P) : + P(P), to_send(P), shares(P.num_players()), from_dealer(false), + sub_player(P) +{ + if (is_dealer()) + internal = 0; + else + internal = new SemiInput>(0, sub_player); +} + +template +DealerInput::~DealerInput() +{ + if (internal) + delete internal; +} + +template +bool DealerInput::is_dealer(int player) +{ + int dealer_player = P.num_players() - 1; + if (player == -1) + return P.my_num() == dealer_player; + else + return player == dealer_player; +} + +template +void DealerInput::reset(int player) +{ + if (player == 0) + { + to_send.reset(P); + from_dealer = false; + } + else if (not is_dealer()) + internal->reset(player - 1); +} + +template +void DealerInput::add_mine(const typename T::open_type& input, + int) +{ + if (is_dealer()) + { + make_share(shares.data(), input, P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(to_send[i]); + from_dealer = true; + } + else + internal->add_mine(input); +} + +template +void DealerInput::add_other(int player, int) +{ + if (is_dealer(player)) + from_dealer = true; + else if (not is_dealer()) + internal->add_other(player); +} + +template +void DealerInput::exchange() +{ + if (from_dealer) + { + vector senders(P.num_players()); + senders.back() = true; + P.send_receive_all(senders, to_send, to_receive); + } + else if (not is_dealer()) + internal->exchange(); +} + +template +T DealerInput::finalize(int player, int) +{ + if (is_dealer()) + return {}; + else + { + if (is_dealer(player)) + return to_receive.back().template get(); + else + return internal->finalize(player); + } +} + +#endif /* PROTOCOLS_DEALERINPUT_HPP_ */ diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h new file mode 100644 index 000000000..5311f8132 --- /dev/null +++ b/Protocols/DealerMC.h @@ -0,0 +1,42 @@ +/* + * DealerMC.h + * + */ + +#ifndef PROTOCOLS_DEALERMC_H_ +#define PROTOCOLS_DEALERMC_H_ + +#include "MAC_Check_Base.h" +#include "Networking/AllButLastPlayer.h" + +template +class DealerMC : public MAC_Check_Base +{ + typedef SemiMC> internal_type; + internal_type& internal; + AllButLastPlayer* sub_player; + +public: + DealerMC(typename T::mac_key_type = {}, int = 0, int = 0); + DealerMC(internal_type& internal); + ~DealerMC(); + + void init_open(const Player& P, int n = 0); + void prepare_open(const T& secret); + void exchange(const Player& P); + typename T::open_type finalize_raw(); + + DealerMC& get_part_MC() + { + return *this; + } +}; + +template +class DirectDealerMC : public DealerMC +{ +public: + DirectDealerMC(typename T::mac_key_type = {}); +}; + +#endif /* PROTOCOLS_DEALERMC_H_ */ diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp new file mode 100644 index 000000000..a9ddc035c --- /dev/null +++ b/Protocols/DealerMC.hpp @@ -0,0 +1,76 @@ +/* + * DealerMC.hpp + * + */ + +#ifndef PROTOCOLS_DEALERMC_HPP_ +#define PROTOCOLS_DEALERMC_HPP_ + +#include "DealerMC.h" + +template +DealerMC::DealerMC(typename T::mac_key_type, int, int) : + DealerMC(*(new internal_type)) +{ +} + +template +DirectDealerMC::DirectDealerMC(typename T::mac_key_type) : + DealerMC(*(new DirectSemiMC>)) +{ +} + +template +DealerMC::DealerMC(internal_type& internal) : + internal(internal), sub_player(0) +{ +} + +template +DealerMC::~DealerMC() +{ + delete &internal; + if (sub_player) + delete sub_player; +} + +template +void DealerMC::init_open(const Player& P, int n) +{ + if (P.my_num() != P.num_players() - 1) + { + if (not sub_player) + sub_player = new AllButLastPlayer(P); + internal.init_open(P, n); + } +} + +template +void DealerMC::prepare_open(const T& secret) +{ + if (sub_player) + internal.prepare_open(secret); + else + { + if (secret != T()) + throw runtime_error("share for dealer should be 0"); + } +} + +template +void DealerMC::exchange(const Player&) +{ + if (sub_player) + internal.exchange(*sub_player); +} + +template +typename T::open_type DealerMC::finalize_raw() +{ + if (sub_player) + return internal.finalize_raw(); + else + return {}; +} + +#endif /* PROTOCOLS_DEALERMC_HPP_ */ diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h new file mode 100644 index 000000000..ae28ec691 --- /dev/null +++ b/Protocols/DealerPrep.h @@ -0,0 +1,33 @@ +/* + * DealerPrep.h + * + */ + +#ifndef PROTOCOLS_DEALERPREP_H_ +#define PROTOCOLS_DEALERPREP_H_ + +#include "ReplicatedPrep.h" + +template +class DealerPrep : virtual public BitPrep +{ + template + void buffer_edabits(int n_bits, true_type); + template + void buffer_edabits(int n_bits, false_type); + +public: + DealerPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage) + { + } + + void buffer_triples(); + void buffer_bits(); + + void buffer_dabits(ThreadQueues* = 0); + void buffer_edabits(int n_bits, ThreadQueues*); + void buffer_sedabits(int n_bits, ThreadQueues*); +}; + +#endif /* PROTOCOLS_DEALERPREP_H_ */ diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp new file mode 100644 index 000000000..d4a0a91dd --- /dev/null +++ b/Protocols/DealerPrep.hpp @@ -0,0 +1,196 @@ +/* + * DealerPrep.hpp + * + */ + +#ifndef PROTOCOLS_DEALERPREP_HPP_ +#define PROTOCOLS_DEALERPREP_HPP_ + +#include "DealerPrep.h" + +template +void DealerPrep::buffer_triples() +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T triples[3]; + for (int i = 0; i < 2; i++) + triples[i] = G.get(); + triples[2] = triples[0] * triples[1]; + for (auto& value : triples) + { + make_share(shares.data(), typename T::clear(value), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + this->triples.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->triples.push_back(to_receive.back().get>().get()); + } +} + +template +void DealerPrep::buffer_bits() +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T bit = G.get_bit(); + make_share(shares.data(), typename T::clear(bit), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + this->bits.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->bits.push_back(to_receive.back().get()); + } +} + +template +void DealerPrep::buffer_dabits(ThreadQueues*) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + vector bit_shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + auto bit = G.get_bit(); + make_share(shares.data(), typename T::clear(bit), + P.num_players() - 1, 0, G); + make_share(bit_shares.data(), typename T::bit_type::clear(bit), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + { + shares.at(i - 1).pack(os[i - 1]); + bit_shares.at(i - 1).pack(os[i - 1]); + } + this->dabits.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + this->dabits.push_back({to_receive.back().get(), + to_receive.back().get()}); + } + } +} + +template +void DealerPrep::buffer_sedabits(int length, ThreadQueues*) +{ + auto& buffer = this->edabits[{false, length}]; + if (buffer.empty()) + buffer_edabits(length, 0); + this->edabits[{true, length}].push_back(buffer.back()); + buffer.pop_back(); +} + +template +void DealerPrep::buffer_edabits(int length, ThreadQueues*) +{ + buffer_edabits(length, T::clear::characteristic_two); +} + +template +template +void DealerPrep::buffer_edabits(int, true_type) +{ + throw not_implemented(); +} + +template +template +void DealerPrep::buffer_edabits(int length, false_type) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + int n_vecs = OnlineOptions::singleton.batch_size / edabitvec::MAX_SIZE; + auto& buffer = this->edabits[{false, length}]; + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + vector bit_shares(P.num_players() - 1); + for (int i = 0; i < n_vecs; i++) + { + vector as; + vector bs; + plain_edabits(as, bs, length, G); + for (auto& a : as) + { + make_share(shares.data(), a, P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + for (auto& b : bs) + { + make_share(bit_shares.data(), b, P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + bit_shares.at(i - 1).pack(os[i - 1]); + } + buffer.push_back({}); + buffer.back().a.resize(edabitvec::MAX_SIZE); + buffer.back().b.resize(length); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < n_vecs; i++) + { + buffer.push_back({}); + for (int j = 0; j < edabitvec::MAX_SIZE; j++) + buffer.back().a.push_back(to_receive.back().get()); + for (int j = 0; j < length; j++) + buffer.back().b.push_back( + to_receive.back().get()); + } + } +} + +#endif /* PROTOCOLS_DEALERPREP_HPP_ */ diff --git a/Protocols/DealerShare.h b/Protocols/DealerShare.h new file mode 100644 index 000000000..38900ff37 --- /dev/null +++ b/Protocols/DealerShare.h @@ -0,0 +1,76 @@ +/* + * DealerShare.h + * + */ + +#ifndef PROTOCOLS_DEALERSHARE_H_ +#define PROTOCOLS_DEALERSHARE_H_ + +#include "Math/Z2k.h" +#include "SemiShare.h" + +template class DealerPrep; +template class DealerInput; +template class DealerMC; +template class DirectDealerMC; + +namespace GC +{ +class DealerSecret; +} + +template +class DealerShare : public SemiShare +{ + typedef DealerShare This; + typedef SemiShare super; + +public: + typedef GC::DealerSecret bit_type; + + typedef DealerMC MAC_Check; + typedef DirectDealerMC Direct_MC; + typedef Beaver Protocol; + typedef DealerInput Input; + typedef DealerPrep LivePrep; + typedef ::PrivateOutput PrivateOutput; + + static false_type dishonest_majority; + const static bool needs_ot = false; + + static string type_short() + { + return "DD" + string(1, T::type_char()); + } + + static int threshold(int) + { + throw runtime_error("undefined threshold"); + } + + static This constant(const T& other, int my_num, + const typename super::mac_key_type& = {}, int = -1) + { + if (my_num == 1) + return other; + else + return {}; + } + + DealerShare() + { + } + + template + DealerShare(const U& other) : super(other) + { + } +}; + +template +using DealerRingShare = DealerShare>; + +template +false_type DealerShare::dishonest_majority; + +#endif /* PROTOCOLS_DEALERSHARE_H_ */ diff --git a/Protocols/FakeMC.h b/Protocols/FakeMC.h index d16dda1cf..b5876ec05 100644 --- a/Protocols/FakeMC.h +++ b/Protocols/FakeMC.h @@ -12,7 +12,7 @@ template class FakeMC : public MAC_Check_Base { public: - FakeMC(T, int = 0, int = 0) + FakeMC(typename T::mac_key_type, int = 0, int = 0) { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index fb55f0cf4..018ac3384 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -28,6 +28,7 @@ class FakeProtocol : public ProtocolBase vector trunc_stats; map cisc_stats; + map ltz_stats; public: Player& P; @@ -54,6 +55,8 @@ class FakeProtocol : public ProtocolBase { cerr << x.second << " " << x.first << endl; } + for (auto& x : ltz_stats) + cerr << "LTZ " << x.first << ": " << x.second << endl; } template @@ -219,6 +222,7 @@ class FakeProtocol : public ProtocolBase { for (size_t i = 0; i < args.size(); i += args[i]) { + ltz_stats[args[i + 4]] += args[i + 1]; assert(i + args[i] <= args.size()); assert(args[i] == 6); for (int j = 0; j < args[i + 1]; j++) diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index 569c136e6..c0a269d1a 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -19,7 +19,6 @@ class FakeShare : public T, public ShareInterface typedef FakeShare This; public: - typedef T mac_key_type; typedef T open_type; typedef T clear; @@ -45,7 +44,7 @@ class FakeShare : public T, public ShareInterface return 0; } - static T constant(T value, int = 0, T = 0) + static T constant(T value, int = 0, mac_key_type = {}) { return value; } diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 8a00c793c..f43260ea1 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -16,6 +16,7 @@ template class Hemi : public Semi { map, HemiMatrixPrep*> matrix_preps; + DataPositions matrix_usage; ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 1eebd3b73..1b3d8f5ba 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -27,7 +27,8 @@ HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, new HemiMatrixPrep(dims[0], dims[1], dims[2], - dynamic_cast(processor.DataF))}); + dynamic_cast(processor.DataF), + matrix_usage)}); return *matrix_preps.at(dims); } diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index ea5a7211c..8038e8efc 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -22,15 +22,15 @@ class HemiMatrixPrep : public BufferPrep> int n_rows, n_inner, n_cols; bool swapped; - DataPositions* usage; LivePrep* prep; HemiMatrixPrep(const HemiMatrixPrep&) = delete; public: - HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) : - super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner), + HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep, + DataPositions& usage) : + super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), prep(&prep) { swapped = n_rows > n_cols; @@ -39,11 +39,6 @@ class HemiMatrixPrep : public BufferPrep> assert(this->n_cols >= this->n_rows); } - ~HemiMatrixPrep() - { - delete usage; - } - void set_protocol(typename ShareMatrix::Protocol&) { } diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index f42212995..b2dd92d21 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -5,6 +5,7 @@ #include "HemiMatrixPrep.h" #include "FHE/Diagonalizer.h" +#include "Tools/Bundle.h" class CipherPlainMultJob : public ThreadJob { diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index c456424e5..ce55bce75 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -30,7 +30,9 @@ void HemiPrep::basic_setup(Player& P) pairwise_machine = new PairwiseMachine(P); auto& machine = *pairwise_machine; auto& setup = machine.setup(); - setup.secure_init(P, machine, T::clear::length(), 40); + setup.params.set_matrix_dim_from_options(); + setup.params.set_sec(OnlineOptions::singleton.security_parameter); + setup.secure_init(P, machine, T::clear::length(), 0); T::clear::template init(); } diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 2250417d0..19d5e72d5 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -122,7 +122,6 @@ template class MAC_Check_Z2k : public Tree_MAC_Check { protected: - vector shares; Preprocessing* prep; W get_random_element(); @@ -130,11 +129,11 @@ class MAC_Check_Z2k : public Tree_MAC_Check public: vector random_elements; - void AddToCheck(const W& share, const T& value, const Player& P); MAC_Check_Z2k(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0); MAC_Check_Z2k(const T& ai, Names& Nms, int thread_num); void prepare_open(const W& secret); + void prepare_open_no_mask(const W& secret); virtual void Check(const Player& P); void set_random_element(const W& random_element); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index fd71d5269..ca607fd76 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -14,6 +14,7 @@ #include #include "Protocols/MAC_Check_Base.hpp" +#include "mac_key.hpp" template const char* TreeSum::mc_timer_names[] = { @@ -118,6 +119,7 @@ template void MAC_Check_::Check(const Player& P) { assert(U::mac_type::invertible); + check_field_size(); if (this->WaitingForCheck() == 0) return; @@ -214,17 +216,15 @@ MAC_Check_Z2k::MAC_Check_Z2k(const T& ai, Names& Nms, } template -void MAC_Check_Z2k::AddToCheck(const W& share, const T& value, const Player& P) +void MAC_Check_Z2k::prepare_open(const W& secret) { - shares.push_back(share.get_share()); - Tree_MAC_Check::AddToCheck(share, value, P); + prepare_open_no_mask(secret + (get_random_element() << W::clear::N_BITS)); } template -void MAC_Check_Z2k::prepare_open(const W& secret) +void MAC_Check_Z2k::prepare_open_no_mask(const W& secret) { - shares.push_back(secret.get_share()); - this->values.push_back(V(secret.get_share())); + this->values.push_back(secret.get_share()); this->macs.push_back(secret.get_mac()); } @@ -269,7 +269,6 @@ void MAC_Check_Z2k::Check(const Player& P) cout << "Checking " << shares[0] << " " << this->vals[0] << " " << this->macs[0] << endl; #endif - int k = V::N_BITS; octet seed[SEED_SIZE]; Create_Random_Seed(seed,P,SEED_SIZE); PRNG G; @@ -290,30 +289,7 @@ void MAC_Check_Z2k::Check(const Player& P) chi.push_back(temp_chi); } - W r = get_random_element(); - T lj = r.get_mac(); - U pj; - pj.assign_zero(); - for (int i = 0; i < this->popen_cnt; ++i) - { - T xji = shares[i]; - V xbarji = xji; - U pji = U((xji - xbarji) >> k); - pj += chi[i] * pji; - } - pj += U(r.get_share()); - - U pbar(pj); - vector pj_stream(P.num_players()); - pj.pack(pj_stream[P.my_num()]); - P.unchecked_broadcast(pj_stream); - for (int j=0; jalphai * y) - (((this->alphai * pbar)) << k) + (lj << k); + T zj = mj - this->alphai * y; vector zjs(P.num_players()); zjs[P.my_num()] = zj; Commit_And_Open(zjs, P); @@ -325,7 +301,6 @@ void MAC_Check_Z2k::Check(const Player& P) this->vals.erase(this->vals.begin(), this->vals.begin() + this->popen_cnt); this->macs.erase(this->macs.begin(), this->macs.begin() + this->popen_cnt); - this->shares.erase(this->shares.begin(), this->shares.begin() + this->popen_cnt); this->popen_cnt=0; if (!zj_sum.is_zero()) { throw mac_fail(); } } diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index 5a60281c6..e855214fd 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -57,7 +57,8 @@ class MAC_Check_Base /// Run opening protocol virtual void exchange(const Player& P) = 0; /// Get next opened value - virtual typename T::open_type finalize_open(); + virtual typename T::clear finalize_open(); + virtual typename T::open_type finalize_raw(); /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 91e2ea86b..59c6c5dec 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -25,7 +25,7 @@ void MAC_Check_Base::POpen_End(vector& values, values.clear(); values.reserve(S.size()); for (size_t i = 0; i < S.size(); i++) - values.push_back(finalize_open()); + values.push_back(finalize_raw()); } template @@ -59,7 +59,13 @@ void MAC_Check_Base::prepare_open(const T& secret) } template -typename T::open_type MAC_Check_Base::finalize_open() +typename T::clear MAC_Check_Base::finalize_open() +{ + return finalize_raw(); +} + +template +typename T::open_type MAC_Check_Base::finalize_raw() { return values.next(); } diff --git a/Protocols/MalRepRingOptions.cpp b/Protocols/MalRepRingOptions.cpp index a2537da66..c5aafc18b 100644 --- a/Protocols/MalRepRingOptions.cpp +++ b/Protocols/MalRepRingOptions.cpp @@ -21,10 +21,10 @@ MalRepRingOptions::MalRepRingOptions(ez::ezOptionParser& opt, int argc, 0, // Number of args expected. 0, // Delimiter if expecting multiple args. "Shuffle sacrifice (default: disabled)", // Help description. - "-S", // Flag token. + "-SH", // Flag token. "--shuffle" // Flag token. ); opt.parse(argc, argv); - shuffle = opt.isSet("-S"); + shuffle = opt.isSet("-SH"); opt.resetArgs(); } diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 96f2c8138..6ce2e2442 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -89,7 +89,7 @@ void MalRepRingPrep::simple_buffer_triples() template void MalRepRingPrep::shuffle_buffer_triples() { - assert(T::SECURITY <= 40); + assert(T::SECURITY <= OnlineOptions::singleton.security_parameter); assert(this->proc != 0); typename T::MAC_Check MC; shuffle_triple_generation(this->triples, this->proc->P, MC); diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index 87deaaa3b..e023945b1 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -49,11 +49,11 @@ class HashMaliciousRepMC : public MaliciousRepMC public: // emulate MAC_Check - HashMaliciousRepMC(const typename T::value_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() + HashMaliciousRepMC(const typename T::mac_key_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() { (void)_; (void)__; (void)___; } // emulate Direct_MAC_Check - HashMaliciousRepMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) : HashMaliciousRepMC() + HashMaliciousRepMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) : HashMaliciousRepMC() { (void)_; (void)__; (void)___; (void)____; } HashMaliciousRepMC(); @@ -62,7 +62,7 @@ class HashMaliciousRepMC : public MaliciousRepMC void POpen(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); - virtual typename T::open_type finalize_open(); + virtual typename T::open_type finalize_raw(); void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index e64db21ad..17eec6f11 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -84,9 +84,9 @@ void HashMaliciousRepMC::POpen_End(vector& values, } template -typename T::open_type HashMaliciousRepMC::finalize_open() +typename T::open_type HashMaliciousRepMC::finalize_raw() { - auto res = ReplicatedMC::finalize_open(); + auto res = ReplicatedMC::finalize_raw(); os.reset_write_head(); res.pack(os); update(); diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 8ffbff7bd..b4d83d1b9 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -7,6 +7,8 @@ #include "Tools/Subroutines.h" #include "Processor/OnlineOptions.h" +#include "mac_key.hpp" + template MaliciousBitOnlyRepPrep::MaliciousBitOnlyRepPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), @@ -69,7 +71,7 @@ void MaliciousBitOnlyRepPrep::init_honest(Player& P) template void MaliciousRepPrep::buffer_triples() { - assert(T::open_type::length() >= 40); + check_field_size(); auto& triples = this->triples; auto buffer_size = this->buffer_size; auto& honest_proc = this->honest_proc; diff --git a/Protocols/MaliciousShamirMC.h b/Protocols/MaliciousShamirMC.h index a72c36b03..4efe07118 100644 --- a/Protocols/MaliciousShamirMC.h +++ b/Protocols/MaliciousShamirMC.h @@ -38,7 +38,7 @@ class MaliciousShamirMC : public ShamirMC { (void)_; (void)__; (void)___; (void)____; } void init_open(const Player& P, int n = 0); - typename T::open_type finalize_open(); + typename T::open_type finalize_raw(); typename T::open_type reconstruct(const vector& shares); }; diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index 41c6f208a..7f66215d5 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -33,7 +33,7 @@ void MaliciousShamirMC::init_open(const Player& P, int n) } template -typename T::open_type MaliciousShamirMC::finalize_open() +typename T::open_type MaliciousShamirMC::finalize_raw() { int threshold = ShamirMachine::s().threshold; shares.resize(2 * threshold + 1); diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index fee8e8292..ceedc9157 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -38,6 +38,9 @@ class MaliciousShamirShare : public ShamirShare typedef MaliciousRepPrep TriplePrep; typedef T random_type; + // indicate security relevance of field size + typedef T mac_key_type; + #ifndef NO_MIXED_CIRCUITS typedef GC::MaliciousCcdSecret bit_type; #endif diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index c9eb63cf6..11942825f 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -25,13 +25,15 @@ template void MamaPrep::buffer_triples() { int mac_security = T::N_MACS * T::clear::length(); + int sec = OnlineOptions::singleton.security_parameter; - if (mac_security < 40) + if (mac_security < sec) { - cerr << T::N_MACS << " MACs are not enough for 40-bit security with " - << T::clear::length() << "-bit primes." << endl; + cerr << T::N_MACS << " MACs are not enough for " << sec + << "-bit security with " << T::clear::length() << "-bit primes." + << endl; cerr << "Compile with -DN_MAMA_MACS=" - << DIV_CEIL(40, T::clear::length()) + << DIV_CEIL(sec, T::clear::length()) << " or remove this check in " << __FILE__ << endl; exit(1); } @@ -45,7 +47,7 @@ void MamaPrep::buffer_triples() size_t required = OnlineOptions::singleton.batch_size; // prefer shuffling if not loosing much security and bucket size is smaller - bool use_shuffling = mac_security <= 42 + bool use_shuffling = mac_security <= (sec + 2) and OnlineOptions::singleton.bucket_size < T::N_MACS; if (use_shuffling) required = sacrifice.minimum_n_inputs(); diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index c90a5e277..f27515185 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -23,6 +23,11 @@ class MamaMac : public FixedVec, N> public: static const true_type invertible; + static int length() + { + return N * T::length(); + } + MamaMac() { } diff --git a/Protocols/NoLivePrep.h b/Protocols/NoLivePrep.h index c53ec7e84..a1b89f9d9 100644 --- a/Protocols/NoLivePrep.h +++ b/Protocols/NoLivePrep.h @@ -32,6 +32,11 @@ class NoLivePrep : public BufferPrep { } + NoLivePrep(DataPositions& usage, int = -1) : + BufferPrep(usage) + { + } + // access to protocol instance if needed void set_protocol(typename T::Protocol&) { diff --git a/Protocols/NoProtocol.h b/Protocols/NoProtocol.h index d8259eb0f..f1ef3c02e 100644 --- a/Protocols/NoProtocol.h +++ b/Protocols/NoProtocol.h @@ -8,6 +8,7 @@ #include "Protocols/Replicated.h" #include "Protocols/MAC_Check_Base.h" +#include "Processor/Input.h" // opening facility template diff --git a/Protocols/NoShare.h b/Protocols/NoShare.h index d966f5867..0532200b6 100644 --- a/Protocols/NoShare.h +++ b/Protocols/NoShare.h @@ -25,9 +25,6 @@ class NoShare : public ShareInterface typedef T clear; typedef clear open_type; - // needs to be defined even if protocol doesn't use MACs - typedef clear mac_key_type; - // disable binary computation typedef GC::NoShare bit_type; diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 44853b79a..afb456621 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -27,8 +27,6 @@ class RepShare : public FixedVec, public ShareInterface public: typedef T clear; typedef T open_type; - typedef T mac_type; - typedef T mac_key_type; const static bool needs_ot = false; const static bool dishonest_majority = false; @@ -138,9 +136,10 @@ class Rep3Share : public RepShare return T::type_char(); } - static Rep3Share constant(T value, int my_num, const T& alphai = {}) + static Rep3Share constant(T value, int my_num, + typename super::mac_key_type = {}) { - return Rep3Share(value, my_num, alphai); + return Rep3Share(value, my_num); } Rep3Share() diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 2357d0f5e..ba5b85c8e 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -53,6 +53,8 @@ class ProtocolBase int trunc_pr_counter; int rounds, trunc_rounds; + int dot_counter; + int bit_counter; public: typedef T share_type; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 1a8a66b99..2d9eba572 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -20,7 +20,8 @@ template ProtocolBase::ProtocolBase() : - trunc_pr_counter(0), rounds(0), trunc_rounds(0), counter(0) + trunc_pr_counter(0), rounds(0), trunc_rounds(0), dot_counter(0), + bit_counter(0), counter(0) { } @@ -67,7 +68,11 @@ ProtocolBase::~ProtocolBase() { #ifdef VERBOSE_COUNT if (counter or rounds) - cerr << "Number of " << T::type_string() << " multiplications: " << counter << " in " << rounds << " rounds" << endl; + cerr << "Number of " << T::type_string() << " multiplications: " + << counter << " (" << bit_counter << " bits) in " << rounds + << " rounds" << endl; + if (counter or rounds) + cerr << "Number of " << T::type_string() << " dot products: " << dot_counter << endl; if (trunc_pr_counter or trunc_rounds) cerr << "Number of probabilistic truncations: " << trunc_pr_counter << " in " << trunc_rounds << " rounds" << endl; #endif @@ -126,6 +131,7 @@ template T ProtocolBase::finalize_dotprod(int length) { counter += length; + dot_counter++; T res; for (int i = 0; i < length; i++) res += finalize_mul(); @@ -199,6 +205,7 @@ template inline T Replicated::finalize_mul(int n) { this->counter++; + this->bit_counter += n; T result; result[0] = add_shares.next(); result[1].unpack(os[1], n); @@ -230,6 +237,7 @@ template inline T Replicated::finalize_dotprod(int length) { (void) length; + this->dot_counter++; return finalize_mul(); } @@ -316,6 +324,7 @@ void Replicated::trunc_pr(const vector& regs, int size, U& proc, for (auto info : infos) for (int i = 0; i < size; i++) { + this->trunc_pr_counter++; auto c_prime = input.finalize(comp_player); auto r_prime = input.finalize(gen_player); S[info.dest_base + i] = c_prime - r_prime; diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index bb6f36a20..17916a2e9 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -22,11 +22,11 @@ class ReplicatedMC : public MAC_Check_Base public: // emulate MAC_Check - ReplicatedMC(const typename T::value_type& _ = {}, int __ = 0, int ___ = 0) + ReplicatedMC(const typename T::mac_key_type& _ = {}, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; } // emulate Direct_MAC_Check - ReplicatedMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) + ReplicatedMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; (void)____; } void POpen(vector& values,const vector& S,const Player& P); @@ -34,7 +34,7 @@ class ReplicatedMC : public MAC_Check_Base void POpen_End(vector& values,const vector& S,const Player& P); virtual void exchange(const Player& P); - virtual typename T::open_type finalize_open(); + virtual typename T::open_type finalize_raw(); void Check(const Player& P) { (void)P; } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index bbcf73e58..e72c0d839 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -65,7 +65,7 @@ void ReplicatedMC::finalize(vector& values, } template -typename T::open_type ReplicatedMC::finalize_open() +typename T::open_type ReplicatedMC::finalize_raw() { auto a = this->secrets.next().sum(); return a + o.get(); diff --git a/Protocols/SPDZ.h b/Protocols/SPDZ.h index fb2888c05..bd804ea0a 100644 --- a/Protocols/SPDZ.h +++ b/Protocols/SPDZ.h @@ -26,7 +26,8 @@ class SPDZ : public Beaver { } - static void assign(typename T::clear& share, const typename T::clear& clear, int my_num) + static void assign(typename T::open_type& share, + const typename T::open_type& clear, int my_num) { if (my_num == 0) share = clear; diff --git a/Protocols/SPDZ2k.h b/Protocols/SPDZ2k.h new file mode 100644 index 000000000..da128fee8 --- /dev/null +++ b/Protocols/SPDZ2k.h @@ -0,0 +1,28 @@ +/* + * SPDZ2k.h + * + */ + +#ifndef PROTOCOLS_SPDZ2K_H_ +#define PROTOCOLS_SPDZ2K_H_ + +#include "SPDZ.h" + +template +class SPDZ2k : public SPDZ +{ +public: + SPDZ2k(Player& P) : + SPDZ(P) + { + } + + void exchange() + { + for (size_t i = 0; i < this->shares.size(); i++) + this->MC->set_random_element({}); + SPDZ::exchange(); + } +}; + +#endif /* PROTOCOLS_SPDZ2K_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index ee5e83202..cc41d0236 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -21,8 +21,6 @@ class Semi2kShare : public SemiShare> typedef SignedZ2 T; public: - typedef Z2<64> mac_key_type; - typedef SemiMC MAC_Check; typedef DirectSemiMC Direct_MC; typedef SemiInput Input; diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 4fc265b7c..c40d0c170 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -18,7 +18,7 @@ class SemiInput : public InputBase { vector send_prngs; vector recv_prngs; - Player& P; + PlayerBase& P; vector> shares; public: @@ -27,7 +27,7 @@ class SemiInput : public InputBase { } - SemiInput(SubProcessor* proc, Player& P); + SemiInput(SubProcessor* proc, PlayerBase& P); SemiInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : SemiInput(0, P) diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 3ed1feefe..f0fefe137 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -11,7 +11,7 @@ #include "ShamirInput.hpp" template -SemiInput::SemiInput(SubProcessor* proc, Player& P) : +SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : InputBase(proc), P(P) { shares.resize(P.num_players()); diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index c2dd90858..b306d5c3d 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -9,6 +9,7 @@ #include "Protocols/Beaver.h" #include "Protocols/Semi.h" #include "Processor/DummyProtocol.h" +#include "GC/NoShare.h" #include "ShareInterface.h" #include @@ -51,8 +52,6 @@ class SemiShare : public T, public ShareInterface typedef T super; public: - typedef T mac_key_type; - typedef T mac_type; typedef T open_type; typedef T clear; @@ -87,10 +86,10 @@ class SemiShare : public T, public ShareInterface return nplayers - 1; } - static SemiShare constant(const clear& other, int my_num, - const T& alphai = {}, int = -1) + static SemiShare constant(const open_type& other, int my_num, + mac_key_type = {}, int = -1) { - return SemiShare(other, my_num, alphai); + return SemiShare(other, my_num); } SemiShare() @@ -100,7 +99,7 @@ class SemiShare : public T, public ShareInterface SemiShare(const U& other) : T(other) { } - SemiShare(const clear& other, int my_num, const T& alphai = {}) + SemiShare(const open_type& other, int my_num, const T& alphai = {}) { (void) alphai; Protocol::assign(*this, other, my_num); diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 6bda92dfc..c6a88f0ad 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -69,7 +69,7 @@ class ShamirMC : public IndirectShamirMC virtual void init_open(const Player& P, int n = 0); virtual void prepare_open(const T& secret); virtual void exchange(const Player& P); - virtual typename T::open_type finalize_open(); + virtual typename T::open_type finalize_raw(); void Check(const Player& P) { (void)P; } diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index e3e7cd3ac..7238aa5ef 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -112,11 +112,11 @@ void ShamirMC::finalize(vector& values, { values.clear(); for (size_t i = 0; i < S.size(); i++) - values.push_back(finalize_open()); + values.push_back(finalize_raw()); } template -typename T::open_type ShamirMC::finalize_open() +typename T::open_type ShamirMC::finalize_raw() { assert(reconstruction.size()); typename T::open_type res; diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index e7daabfcf..aea0bb97a 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -29,10 +29,7 @@ class ShamirShare : public T, public ShareInterface public: typedef T clear; typedef T open_type; - typedef T mac_key_type; typedef void sacri_type; - typedef GC::NoShare mac_type; - typedef GC::NoShare mac_share_type; typedef Shamir Protocol; typedef IndirectShamirMC MAC_Check; @@ -76,9 +73,9 @@ class ShamirShare : public T, public ShareInterface return Protocol::get_rec_factor(i, n); } - static ShamirShare constant(T value, int my_num, const T& alphai = {}) + static ShamirShare constant(T value, int, const mac_key_type& = {}) { - return ShamirShare(value, my_num, alphai); + return ShamirShare(value); } ShamirShare() @@ -89,42 +86,12 @@ class ShamirShare : public T, public ShareInterface { T::operator=(other); } - template - ShamirShare(const U& other, int my_num, T alphai = {}) : ShamirShare(other) - { - (void) my_num, (void) alphai; - } - // Share compatibility - void assign(clear other, int my_num, const T& alphai) - { - (void)alphai, (void)my_num; - *this = other; - } void assign(const char* buffer) { T::assign(buffer); } - void add(const ShamirShare& S, const clear aa, int my_num, - const T& alphai) - { - (void) my_num, (void) alphai; - *this = S + aa; - } - void sub(const ShamirShare& S, const clear& aa, int my_num, - const T& alphai) - { - (void) my_num, (void) alphai; - *this = S - aa; - } - void sub(const clear& aa, const ShamirShare& S, int my_num, - const T& alphai) - { - (void) my_num, (void) alphai; - *this = aa - S; - } - ShamirShare operator<<(int i) { return *this * (T(1) << i); diff --git a/Protocols/Share.h b/Protocols/Share.h index 92be4f144..e2a9f0bb5 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -73,7 +73,7 @@ class Share_ : public ShareInterface static void specification(octetStream& os) { T::specification(os); } - static Share_ constant(const clear& aa, int my_num, const typename V::Scalar& alphai) + static Share_ constant(const open_type& aa, int my_num, const typename V::Scalar& alphai) { return Share_(aa, my_num, alphai); } template @@ -85,12 +85,12 @@ class Share_ : public ShareInterface { a.assign_zero(); mac.assign_zero(); } - void assign(const clear& aa, int my_num, const typename V::Scalar& alphai); + void assign(const open_type& aa, int my_num, const typename V::Scalar& alphai); Share_() { assign_zero(); } template Share_(const Share_& S) { assign(S); } - Share_(const clear& aa, int my_num, const typename V::Scalar& alphai) + Share_(const open_type& aa, int my_num, const typename V::Scalar& alphai) { assign(aa, my_num, alphai); } Share_(const T& share, const V& mac) : a(share), mac(mac) {} @@ -128,6 +128,8 @@ class Share_ : public ShareInterface void force_to_bit() { a.force_to_bit(); } + void randomize(PRNG& G); + // Input and output from a stream // - Can do in human or machine only format (later should be faster) void output(ostream& s,bool human) const @@ -235,7 +237,7 @@ inline void Share_::mul(const Share_& S,const clear& aa) } template -inline void Share_::assign(const clear& aa, int my_num, +inline void Share_::assign(const open_type& aa, int my_num, const typename V::Scalar& alphai) { a = T::constant(aa, my_num); diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index 62f90e91d..c6f675f7f 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -23,6 +23,13 @@ void Share_::read_or_generate_mac_key(string directory, const Player& P, } } +template +void Share_::randomize(PRNG& G) +{ + a.randomize(G); + mac.randomize(G); +} + template inline void Share_::pack(octetStream& os, bool full) const { diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 444214e47..e5af8dddd 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -20,6 +20,7 @@ class ValueInterface; namespace GC { class NoShare; +class NoValue; } class ShareInterface @@ -28,6 +29,10 @@ class ShareInterface typedef GC::NoShare part_type; typedef GC::NoShare bit_type; + typedef GC::NoValue mac_key_type; + typedef GC::NoShare mac_type; + typedef GC::NoShare mac_share_type; + static const bool needs_ot = false; static const bool expensive = false; static const bool expensive_triples = false; diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index b7fdf50be..7f84213e6 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -128,7 +128,7 @@ class ShareMatrix : public ValueMatrix, public ShareInterface typedef ValueMatrix clear; typedef clear open_type; - typedef typename T::clear mac_key_type; + typedef typename T::mac_key_type mac_key_type; static string type_string() { diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 81e859319..4d03dd67e 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -14,7 +14,7 @@ inline ShuffleSacrifice::ShuffleSacrifice() : - B(OnlineOptions::singleton.bucket_size), C(this->B) + ShuffleSacrifice(OnlineOptions::singleton.bucket_size, 3) { } @@ -22,6 +22,9 @@ inline ShuffleSacrifice::ShuffleSacrifice(int B, int C) : B(B), C(C) { + if (OnlineOptions::singleton.security_parameter > 40) + throw runtime_error("shuffle sacrifice not implemented for more than " + "40-bit security"); } template diff --git a/Protocols/SohoPrep.hpp b/Protocols/SohoPrep.hpp index 48deeadc4..1dfd3ecb9 100644 --- a/Protocols/SohoPrep.hpp +++ b/Protocols/SohoPrep.hpp @@ -21,6 +21,7 @@ void SohoPrep::basic_setup(Player& P) assert(not setup); setup = new PartSetup; MachineBase machine; + setup->params.set_sec(OnlineOptions::singleton.security_parameter); setup->secure_init(P, machine, T::clear::length(), 0); read_or_generate_secrets(*setup, P, machine, 1, true_type()); T::clear::template init(); diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 03a91ff25..d95a713d1 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -8,7 +8,8 @@ #include "MascotPrep.h" #include "RingOnlyPrep.h" -#include "Spdz2kShare.h" + +template class Spdz2kShare; template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index 401070f84..762cd34df 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -18,6 +18,7 @@ template class Spdz2kMultiplier; template class Spdz2kTripleGenerator; +template class SPDZ2k; namespace GC { @@ -48,7 +49,7 @@ class Spdz2kShare : public Share> typedef MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef SPDZ2k Protocol; typedef Spdz2kPrep LivePrep; #ifndef NO_MIXED_CIRCUITS diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 2ea08ba46..b7a8c741f 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -5,6 +5,8 @@ #include "SpdzWise.h" +#include "mac_key.hpp" + template SpdzWise::SpdzWise(Player& P) : internal(P), internal2(P), P(P) @@ -142,6 +144,7 @@ template void SpdzWise::zero_check(check_type t) { assert(T::clear::invertible); + check_field_size(); auto r = internal.get_random(); internal.init_mul(); internal.prepare_mul(t, r); diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9991dafb2..9ad761985 100644 --- a/Protocols/SpdzWiseMC.h +++ b/Protocols/SpdzWiseMC.h @@ -44,9 +44,9 @@ class SpdzWiseMC : public MAC_Check_Base { inner_MC.exchange(P); } - typename T::open_type finalize_open() + typename T::open_type finalize_raw() { - return inner_MC.finalize_open(); + return inner_MC.finalize_raw(); } void Check(const Player& P) { diff --git a/Protocols/SpdzWiseShare.hpp b/Protocols/SpdzWiseShare.hpp index 6401c083a..acd6fa298 100644 --- a/Protocols/SpdzWiseShare.hpp +++ b/Protocols/SpdzWiseShare.hpp @@ -13,14 +13,15 @@ template void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& mac_key) { + bool fresh = false; + try { read_mac_key(directory, P.N, mac_key); } catch (mac_key_error&) { - SeededPRNG G; - mac_key.randomize(G); + fresh = true; } try @@ -33,11 +34,12 @@ void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& } catch (mac_fail&) { -#ifdef VERBOSE - cerr << "Generating fresh MAC key for " << type_string() << endl; -#endif - mac_key = typename T::Honest::Protocol(P).get_random(); + fresh = true; + cerr << "Invalid " << type_string() << " MAC key, generating fresh one" << endl; } + + if (fresh) + mac_key = typename T::Honest::Protocol(P).get_random(); } template diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h index de7406bba..ad12837a8 100644 --- a/Protocols/TemiPrep.h +++ b/Protocols/TemiPrep.h @@ -64,6 +64,8 @@ class TemiPrep : public SemiHonestRingPrep { } + ~TemiPrep(); + void buffer_triples(); vector*>& get_multipliers(); diff --git a/Protocols/TemiPrep.hpp b/Protocols/TemiPrep.hpp index 1088a99cc..1f2a415cc 100644 --- a/Protocols/TemiPrep.hpp +++ b/Protocols/TemiPrep.hpp @@ -24,6 +24,7 @@ void TemiPrep::basic_setup(Player& P) assert(not setup); setup = new TemiSetup; MachineBase machine; + setup->params.set_sec(OnlineOptions::singleton.security_parameter); setup->secure_init(P, T::clear::length()); read_or_generate_secrets(*setup, P, machine, 1, true_type()); T::clear::template init(); @@ -104,6 +105,13 @@ TemiMultiplier::TemiMultiplier(Player& P) : P(P) { } +template +TemiPrep::~TemiPrep() +{ + for (auto& x : multipliers) + delete x; +} + template vector& TemiMultiplier::get_multiplicands( vector >& ciphertexts, const FHE_PK& pk) diff --git a/Protocols/config.h b/Protocols/config.h new file mode 100644 index 000000000..f88c3aa8a --- /dev/null +++ b/Protocols/config.h @@ -0,0 +1,13 @@ +/* + * config.h + * + */ + +#ifndef PROTOCOLS_CONFIG_H_ +#define PROTOCOLS_CONFIG_H_ + +#ifndef DEFAULT_SECURITY +#define DEFAULT_SECURITY 40 +#endif + +#endif /* PROTOCOLS_CONFIG_H_ */ diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index d15581ebd..dc9414278 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -11,7 +11,7 @@ using namespace std; template void check_share(vector& Sa, typename T::clear& value, - typename T::value_type& mac, int N, const typename T::value_type& key); + typename T::mac_type& mac, int N, const typename T::mac_key_type& key); template class Share; diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index aeb516118..bae415c4a 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -20,6 +20,7 @@ template class FixedVec; template class Share_; template class SpdzWiseShare; template class MaliciousRep3Share; +template class DealerShare; namespace GC { @@ -115,7 +116,6 @@ void make_share(GC::TinierSecret* Sa, const U& a, int N, const V& key, PRNG& template void make_share(SemiShare* Sa,const T& a,int N,const U&,PRNG& G) { - insecure("share generation", false); T x, S = a; for (int i=0; i* Sa,const T& a,int N,const U&,PRNG& G) Sa[N-1]=S; } +template +void make_share(DealerShare* Sa, const T& a, int N, const U&, PRNG& G) +{ + make_share((SemiShare*) Sa, a, N - 1, U(), G); + Sa[N - 1] = {}; +} + template void make_share(FixedVec* Sa, const V& a, int N, const U& key, PRNG& G); @@ -234,7 +241,7 @@ void check_share(vector >& Sa, template void check_share(vector& Sa, typename T::clear& value, - typename T::value_type& mac, int N, const typename T::value_type& key) + typename T::mac_type& mac, int N, const typename T::mac_key_type& key) { assert(N == 3); value = 0; @@ -340,23 +347,27 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, { if (directory == "") directory = get_prep_sub_dir(P.num_players()); - typename T::mac_key_type res; + typename T::mac_key_type res, tmp; try { - read_mac_key(directory, P.my_num(), P.num_players(), res); + read_mac_key(directory, P.my_num(), P.num_players(), tmp); } catch (mac_key_error&) { - T::read_or_generate_mac_key(directory, P, res); - write_mac_key(directory, P.my_num(), P.num_players(), res); } + T::read_or_generate_mac_key(directory, P, res); + + // only write if changed + if (tmp != res) + write_mac_key(directory, P.my_num(), P.num_players(), res); + return res; } template -void read_global_mac_key(const string& directory, int nparties, U& key) +void read_global_mac_key(const string& directory, int nparties, U& key, false_type) { U pp; key.assign_zero(); @@ -372,6 +383,17 @@ void read_global_mac_key(const string& directory, int nparties, U& key) cout << "Final Keys : " << key << endl; } +template +void read_global_mac_key(const string&, int, U&, true_type) +{ +} + +template +void read_global_mac_key(const string& directory, int nparties, U& key) +{ + read_global_mac_key(directory, nparties, key, is_same()); +} + template T reconstruct(vector& shares) { @@ -548,4 +570,25 @@ void make_inverse(const typename T::mac_type& key, int N, int ntrip, bool zero, check_files(files.outf, N); } +template +void plain_edabits(vector& as, + vector& bs, int length, PRNG& G, + bool zero = false) +{ + int max_size = edabitvec::MAX_SIZE; + as.resize(max_size); + bs.clear(); + bs.resize(length); + bigint value; + for (int j = 0; j < max_size; j++) + { + if (not zero) + G.get_bigint(value, length, true); + as[j] = value; + for (int k = 0; k < length; k++) + bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; + } + +} + #endif diff --git a/Protocols/mac_key.hpp b/Protocols/mac_key.hpp index 843e6d47f..3f1000b89 100644 --- a/Protocols/mac_key.hpp +++ b/Protocols/mac_key.hpp @@ -22,4 +22,12 @@ typename T::mac_key_type read_or_generate_mac_key(const Player& P, return res; } +template +void check_field_size() +{ + if (T::length() < OnlineOptions::singleton.security_parameter) + throw runtime_error("Field too small for chosen security. " + "Increase size with -lgp or decrease security with -S"); +} + #endif /* PROTOCOLS_MAC_KEY_HPP_ */ diff --git a/README.md b/README.md index 96aa0fd53..d44190e10 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ parties and malicious security. On Linux, this requires a working toolchain and [all requirements](#requirements). On Ubuntu, the following might suffice: ``` -apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm +sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm ``` On MacOS, this requires [brew](https://brew.sh) to be installed, which will be used for all dependencies. @@ -103,6 +103,7 @@ The following table lists all protocols that are fully supported. | Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A | Modulo prime and modulo 2^k are the two settings that allow integer-like computation. For k = 64, the latter corresponds to the @@ -174,6 +175,9 @@ there are a few things to consider: adding `program.use_trunc_pr = True` at the beginning of your high-level program. +- Larger number of parties: ATLAS scales better than the plain Shamir + protocol, and Temi scale better than Hemi or Semi. + - Minor variants: Some command-line options change aspects of the protocols such as: @@ -771,7 +775,23 @@ the number of parties with `-N` and the maximum number of corrupted parties with `-T`. The latter can be at most half the number of parties. -### BMR +## Dealer model + +This security model defines a special party that generates correlated +randomness such as multiplication triples, which is then used by all +other parties. MP-SPDZ implements the canonical protocol where the +other parties run the online phase of the semi-honest protocol in +Semi(2k/Bin) and the dealer provides all preprocessing. The security +assumption is that dealer doesn't collude with any other party, but +all but one of the other parties are allowed to collude. In our +implementation, the dealer is the party with the highest number, so +with three parties overall, Party 0 and 1 run the online phase. + +| Program | Sharing | Domain | Malicious | \# parties | Script | +| --- | --- | --- | --- | --- | --- | +| `dealer-ring-party.x` | Additive | Mod 2^k | N | 3+ | `dealer-ring.sh` | + +## BMR BMR (Bellare-Micali-Rogaway) is a method of generating a garbled circuit using another secure computation protocol. We have implemented BMR diff --git a/Scripts/dealer-ring.sh b/Scripts/dealer-ring.sh new file mode 100755 index 000000000..b0d6692a5 --- /dev/null +++ b/Scripts/dealer-ring.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +. $HERE/run-common.sh + +run_player dealer-ring-party.x $* || exit 1 diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py index 15959ee68..eaec677fd 100755 --- a/Scripts/memory-usage.py +++ b/Scripts/memory-usage.py @@ -12,7 +12,7 @@ print('Usage: %s ' % sys.argv[0]) res = collections.defaultdict(lambda: 0) -m = 0 +regs = collections.defaultdict(lambda: 0) for tapename in Program.read_tapes(sys.argv[1]): for inst in Tape.read_instructions(tapename): @@ -22,8 +22,9 @@ res[t.arg_format[0]]) for arg in inst.args: if isinstance(arg, RegisterArgFormat): - m = max(m, arg.i + inst.size) + regs[type(arg)] = max(regs[type(arg)], arg.i + inst.size) -print (res) -print (m) +reverse_formats = dict((v, k) for k, v in ArgFormats.items()) +print ('Memory:', dict(res)) +print ('Registers:', dict((reverse_formats[t], n) for t, n in regs.items())) diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index e8c02f6cb..3771383b7 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -52,7 +52,7 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py -R 64 $compile_opts tutorial for i in ring rep4-ring semi2k brain mal-rep-ring ps-rep-ring sy-rep-ring \ - spdz2k; do + spdz2k dealer-ring; do test_vm $i $run_opts done @@ -65,7 +65,7 @@ for dabit in ${dabit:-0 1 2}; do done for i in cowgear chaigear; do - test_vm $i $run_opts -l 3 -c 2 + test_vm $i $run_opts -S 3 -c 2 done done @@ -83,7 +83,7 @@ fi ./compile.py tutorial for i in cowgear chaigear; do - test_vm $i $run_opts -l 3 -c 2 -J + test_vm $i $run_opts -S 3 -c 2 -J done if test $skip_binary; then diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index ed6c01441..ce906a3a3 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -27,7 +27,11 @@ if test "$flags"; then cpu=amd64 fi - cp -av bin/`uname`-$cpu/* . || { echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2; exit 1; } + if ! cp -av bin/`uname`-$cpu/* .; then + echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2 + echo Make sure NOT to download a source code only file 1>&2 + exit 1 + fi fi mkdir Player-Data 2> /dev/null diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 9dd15804c..f3e67c82f 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -85,6 +85,11 @@ void BufferBase::try_rewind() void BufferBase::prune() { + // only prune in secure mode +#ifdef INSECURE + return; +#endif + if (is_pipe()) return; diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index f6f4ba2ec..c7f8c371f 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -83,3 +83,8 @@ not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& fil "adding -DINSECURE to the compiler options.") { } + +gf2n_not_supported::gf2n_not_supported(int n) : + runtime_error("GF(2^" + to_string(n) + ") not supported") +{ +} diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index fff8b2de4..bb347c6a8 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -278,4 +278,10 @@ class insufficient_memory : public runtime_error insufficient_memory(size_t size, const string& type); }; +class gf2n_not_supported : public runtime_error +{ +public: + gf2n_not_supported(int n); +}; + #endif diff --git a/Tools/avx_memcpy.h b/Tools/avx_memcpy.h index 231dc99cf..a00a215e2 100644 --- a/Tools/avx_memcpy.h +++ b/Tools/avx_memcpy.h @@ -20,6 +20,7 @@ template inline void avx_memcpy(void* dest, const void* source) { size_t length = L; +#ifdef __SSE2__ __m256i* d = (__m256i*)dest, *s = (__m256i*)source; #ifdef __AVX__ while (length >= 32) @@ -35,6 +36,10 @@ inline void avx_memcpy(void* dest, const void* source) _mm_storeu_si128(d2++, _mm_loadu_si128(s2++)); length -= 16; } +#else + void* d2 = dest; + const void* s2 = source; +#endif switch (length) { case 0: @@ -53,14 +58,16 @@ inline void avx_memcpy(void* dest, const void* source) inline void avx_memzero(void* dest, size_t length) { - __m256i* d = (__m256i*)dest; #ifdef __AVX__ + __m256i* d = (__m256i*)dest; __m256i s = _mm256_setzero_si256(); while (length >= 32) { _mm256_storeu_si256(d++, s); length -= 32; } +#else + void* d = dest; #endif switch (length) { diff --git a/Tools/benchmarking.cpp b/Tools/benchmarking.cpp index e956f15ec..88eee709f 100644 --- a/Tools/benchmarking.cpp +++ b/Tools/benchmarking.cpp @@ -5,6 +5,25 @@ #include "benchmarking.h" +void insecure(string message, bool warning) +{ +#ifdef INSECURE + if (warning) + cerr << "WARNING: insecure " << message << endl; +#else + (void)warning; + string msg = "You are trying to use insecure benchmarking functionality for " + + message + ".\nYou can activate this at compile time " + "by adding -DINSECURE to the compiler options.\n" + "Make sure to run 'make clean' as well before compiling."; + cerr << msg << endl; +#ifdef INSECURE_EXCEPTION + throw exception(); +#endif + exit(1); +#endif +} + void insecure_fake() { #if defined(INSECURE) or defined(INSECURE_FAKE) diff --git a/Tools/benchmarking.h b/Tools/benchmarking.h index 13fa9c365..e54990ca7 100644 --- a/Tools/benchmarking.h +++ b/Tools/benchmarking.h @@ -12,20 +12,7 @@ using namespace std; // call before insecure benchmarking functionality -inline void insecure(string message, bool warning = true) -{ -#ifdef INSECURE - if (warning) - cerr << "WARNING: insecure " << message << endl; -#else - (void)warning; - string msg = "You are trying to use insecure benchmarking functionality for " - + message + ".\nYou can activate this at compile time " - "by adding -DINSECURE to the compiler options.\n" - "Make sure to run make clean as well."; - throw runtime_error(msg); -#endif -} +void insecure(string message, bool warning = true); void insecure_fake(); diff --git a/Tools/intrinsics.h b/Tools/intrinsics.h index 45664a72b..e7cb87dc5 100644 --- a/Tools/intrinsics.h +++ b/Tools/intrinsics.h @@ -10,6 +10,7 @@ #include #include #else +#ifdef __aarch64__ #define SIMDE_X86_AVX_ENABLE_NATIVE_ALIASES #define SIMDE_X86_AVX2_ENABLE_NATIVE_ALIASES #define SIMDE_X86_SSE2_ENABLE_NATIVE_ALIASES @@ -18,5 +19,6 @@ #include "simde/simde/x86/clmul.h" #include "aes-arm.h" #endif +#endif #endif /* TOOLS_INTRINSICS_H_ */ diff --git a/Tools/parse.h b/Tools/parse.h index af5e6de4c..8ff0fee9e 100644 --- a/Tools/parse.h +++ b/Tools/parse.h @@ -23,6 +23,14 @@ inline int get_int(istream& s) return be32toh(n); } +// Read an 8-byte integer +inline int64_t get_long(istream& s) +{ + int64_t n; + s.read((char*) &n, 8); + return be64toh(n); +} + // Read several integers inline void get_ints(int* res, istream& s, int count) { diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index 3a3644144..203328f22 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -36,7 +36,8 @@ string PREP_DATA_PREFIX; template void check_mult_triples(const typename T::mac_key_type& key,int N,vector*>& dataF) { - typename T::clear a,b,c,mac; + typename T::clear a,b,c; + typename T::mac_type mac; vector Sa(N),Sb(N),Sc(N); int n = 0; @@ -99,7 +100,8 @@ void check_tuple(const T& a, const T& b, int n, Dtype type) template void check_tuples(const typename T::mac_key_type& key,int N,vector*>& dataF, Dtype type) { - typename T::clear a,b,c,mac,res; + typename T::clear a,b,c,res; + typename T::mac_type mac; vector Sa(N),Sb(N),Sc(N); int n = 0; @@ -127,7 +129,8 @@ void check_tuples(const typename T::mac_key_type& key,int N,vector void check_bits(const typename T::mac_key_type& key,int N,vector*>& dataF) { - typename T::clear a,b,c,mac,res; + typename T::clear a,b,c,res; + typename T::mac_type mac; vector Sa(N),Sb(N),Sc(N); int n = 0; @@ -157,7 +160,8 @@ void check_bits(const typename T::mac_key_type& key,int N,vector void check_inputs(const typename T::mac_key_type& key,int N,vector*>& dataF) { - typename T::clear a, mac, x; + typename T::clear a, x; + typename T::mac_type mac; vector Sa(N); for (int player = 0; player < N; player++) diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index f1158cfa6..823c318bb 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -154,16 +154,9 @@ void FakeParams::make_edabits(const typename T::mac_type& key, int N, int ntrip, int max_size = edabitvec::MAX_SIZE; for (int i = 0; i < ntrip / max_size; i++) { - vector as(max_size); - vector bs(length); - for (int j = 0; j < max_size; j++) - { - if (not zero) - G.get_bigint(value, length, true); - as[j] = value; - for (int k = 0; k < length; k++) - bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; - } + vector as; + vector bs; + plain_edabits(as, bs, length, G, zero); for (auto& a : as) files.template output_shares(a); for (auto& b : bs) @@ -737,9 +730,12 @@ int FakeParams::generate() if (nplayers == 3) { make_bits>({}, nplayers, nbitsp, zero); - make_basic>({}, nplayers, default_num, zero); - make_basic>({}, nplayers, default_num, zero); - make_with_mac_key>(nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, + zero); + make_basic>({}, nplayers, + default_num, zero); + make_with_mac_key>(nplayers, + default_num, zero); make_mult_triples({}, nplayers, ntrip2, zero, prep_data_prefix); make_bits({}, nplayers, nbits2, zero); @@ -748,17 +744,21 @@ int FakeParams::generate() make_basic>({}, nplayers, default_num, zero); make_basic>>({}, nplayers, default_num, zero); + make_basic>>({}, nplayers, default_num, zero); + make_minimal({}, nplayers, default_num, zero); make_mult_triples({}, nplayers, default_num, zero, prep_data_prefix); make_bits({}, nplayers, default_num, zero); gf2n_short::reset(); - gf2n_short::init_field(40); + gf2n_short::init_field(); - Z2<41> keyt; - generate_mac_keys>(keyt, nplayers, prep_data_prefix); + Z2 keyt; + generate_mac_keys>(keyt, nplayers, + prep_data_prefix); - make_minimal>(keyt, nplayers, default_num / 64, zero); + make_minimal>(keyt, nplayers, + default_num / 64, zero); gf2n_short keytt; generate_mac_keys>(keytt, nplayers, prep_data_prefix); diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index 962b27753..d00acd2af 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -20,6 +20,7 @@ #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/SemiSecret.hpp" #include "Protocols/Atlas.hpp" #include "Protocols/MaliciousRepPrep.hpp" #include "Protocols/Share.hpp" diff --git a/Utils/l2h-example.cpp b/Utils/l2h-example.cpp new file mode 100644 index 000000000..475bcb8aa --- /dev/null +++ b/Utils/l2h-example.cpp @@ -0,0 +1,54 @@ +/* + * l2h-example.cpp + * + */ + +#include "Protocols/ProtocolSet.h" + +#include "Math/gfp.hpp" +#include "Machines/SPDZ.hpp" + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 2) + { + cerr << "Usage: " << argv[0] << " " << endl; + exit(1); + } + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + + // template parameters are share types for integer and GF(2^n) computation + Machine, Share> machine(N); + + // protocols to be used directly + ProtocolSet> set(machine.get_player(), machine.get_sint_mac_key()); + + // data to be used in steps + set.input.reset_all(machine.get_player()); + set.input.add_from_all(2 + my_number); + set.input.exchange(); + machine.Mp.MS.resize(n_parties); + for (int i = 0; i < n_parties; i++) + machine.Mp.MS[i] = set.input.finalize(i); + + machine.run_step("l2h_multiplication"); + machine.run_step("l2h_comparison"); + + // check results + // multiplication + assert(set.output.open(machine.Mp.MS[2], machine.get_player()) == 6); + // comparison + assert(set.output.open(machine.Mp.MS[3], machine.get_player()) == 1); + + set.check(); + + // print usage + auto res = machine.stop_threads(); + res.first.print_cost(); +} diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp index a36949d6f..6eda84e03 100644 --- a/Utils/mixed-example.cpp +++ b/Utils/mixed-example.cpp @@ -6,6 +6,7 @@ #include "Protocols/ProtocolSet.h" #include "Machines/SPDZ.hpp" +#include "Machines/SPDZ2k.hpp" #include "Machines/Semi2k.hpp" #include "Machines/Rep.hpp" #include "Machines/Rep4.hpp" diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 83571c218..e5346ade6 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -9,6 +9,7 @@ #include "Math/gfp.hpp" #include "Machines/SPDZ.hpp" +#include "Machines/SPDZ2k.hpp" #include "Machines/MalRep.hpp" #include "Machines/ShamirMachine.hpp" #include "Machines/Semi2k.hpp" @@ -30,7 +31,9 @@ int main(int argc, char** argv) // need player number and number of players if (argc < 3) { - cerr << "Usage: " << argv[0] << " [protocol [threshold]]" << endl; + cerr << "Usage: " << argv[0] + << " [protocol [threshold]]" + << endl; exit(1); } diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 38cdc922e..62896b3a3 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -250,7 +250,7 @@ void YaoEvalWire::convcbit2s(GC::Processor& processor, for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) { auto& dest = processor.S[instruction.get_r(0) + i]; - dest.resize_regs(min(unsigned(unit), instruction.get_n() - i * unit)); + dest.resize_regs(min(size_t(unit), instruction.get_n() - i * unit)); for (auto& reg : dest.get_regs()) reg.set(0); } diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 05a8646db..e50628086 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -238,7 +238,7 @@ void YaoGarbleWire::convcbit2s(GC::Processor& processor, for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) { auto& dest = processor.S[instruction.get_r(0) + i]; - int n = min(unsigned(unit), instruction.get_n() - i * unit); + int n = min(size_t(unit), instruction.get_n() - i * unit); dest.resize_regs(n); for (int j = 0; j < n; j++) dest.get_reg(j).public_input( diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index a947b1f51..b1e0e0736 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -33,7 +33,7 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) "--threshold" // Flag token. ); auto& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, false_type()}; + online_opts = {opt, argc, argv, false}; NetworkOptionsWithNumber network_opts(opt, argc, argv, 2, false); online_opts.finalize(opt, argc, argv); diff --git a/Yao/YaoWire.hpp b/Yao/YaoWire.hpp index aa04fe357..984db38f1 100644 --- a/Yao/YaoWire.hpp +++ b/Yao/YaoWire.hpp @@ -55,7 +55,7 @@ void YaoWire::andm(GC::Processor& processor, for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) { auto &dest = processor.S[instruction.get_r(0) + i]; - int n = min(unsigned(unit), instruction.get_n() - i * unit); + int n = min(size_t(unit), instruction.get_n() - i * unit); dest.resize_regs(n); for (int j = 0; j < n; j++) if (processor.C[instruction.get_r(2) + i].get_bit(j)) diff --git a/doc/Compiler.rst b/doc/Compiler.rst index df3e13f5c..db5c1e9c2 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -28,12 +28,16 @@ Compiler.GC.types module .. automodule:: Compiler.GC.types :members: :no-undoc-members: - :no-inherited-members: - :show-inheritance: + :inherited-members: :exclude-members: PreOp, cbit, dynamic_array, conv_cint_vec, bitdec, bit_type, bitcom, clear_type, conv_regint, default_type, mov, dyn_sbits, int_type, mul, vec, load_mem, - DynamicArray, get_raw_input_from + DynamicArray, get_raw_input_from, bits, + input_tensor_from, input_tensor_from_client, + input_tensor_via, dot_product, Matrix, Tensor, + from_sint, read_from_file, receive_from_client, + reveal_to_clients, write_shares_to_socket, + write_to_file Compiler.library module ----------------------- diff --git a/doc/non-linear.rst b/doc/non-linear.rst index e5df4c204..969e6d6c3 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -56,6 +56,10 @@ Power-of-two modulus mask-and-reveal approach above to the setting of computation modulo a power of two. +See also `this slide deck +`_ for an +introduction to non-linear computation in arithmetic MPC. + Mixed-Circuit Computation ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -70,6 +74,12 @@ more general methods such as `daBits `_ and `edaBits `_. +See also `this slide deck +`_ for an introduction to +mixed-circuit computation. + + +.. _protocol-pairs: Protocol Pairs ============== From 4e811ec1597edfbe2236d21dfdda65f4ce247413 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Wed, 20 Apr 2022 15:51:24 -0500 Subject: [PATCH 056/221] Fix tiny typo in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d44190e10..a3f6741fd 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ there are a few things to consider: - Computation domain: Arithmetic protocols (modulo prime or power of two) are preferable for many applications because they offer integer addition and multiplication at low cost. However, binary circuits - might a better option if there is very little integer + might be a better option if there is very little integer computation. [See below](#finding-the-most-efficient-variant) to find the most efficient mixed-circuit variant. Furthermore, local computation modulo a power of two is cheaper, but MP-SPDZ does not From a5917de3cf5143baaf89e2909384caaf50ddd921 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 21 Apr 2022 12:36:54 +0200 Subject: [PATCH 057/221] Protocol setup with exact modulus. --- Protocols/ProtocolSetup.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h index b6d91b2bc..5953b6663 100644 --- a/Protocols/ProtocolSetup.h +++ b/Protocols/ProtocolSetup.h @@ -35,6 +35,22 @@ class ProtocolSetup T::read_or_generate_mac_key(directory, P, mac_key); } + /** + * @param prime modulus for computation + * @param P communication instance (used for MAC generation if needed) + * @param directory location to read MAC if needed + */ + ProtocolSetup(bigint prime, Player& P, string directory = "") + { + static_assert(T::clear::prime_field, "must use computation modulo a prime"); + + T::clear::init_field(prime); + T::clear::next::init_field(prime, false); + + // must initialize MAC key for security of some protocols + T::read_or_generate_mac_key(directory, P, mac_key); + } + ~ProtocolSetup() { T::LivePrep::teardown(); From 2760659ad4cd740e659447a5854b933daaa139e9 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Thu, 21 Apr 2022 22:54:18 -0500 Subject: [PATCH 058/221] Fix comment/example in Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index e01dd5c53..5d8c28888 100644 --- a/Dockerfile +++ b/Dockerfile @@ -132,7 +132,7 @@ RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/ # --build-arg machine=replicated-ring-party.x \ # # --build-arg prep_dir=/opt/prep \ # # --build-arg ssl_dir=/opt/ssl \ # -# --build-arg nparties=3 \ # +# --build-arg cryptoplayers=3 \ # # --build-arg compile_options="--ring=64" . # # # # Test it: # From a858e5b440902ec25dbb97c555eef35e12fbf69c Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 21 Apr 2022 18:29:27 +0200 Subject: [PATCH 059/221] Security bug in homomorphic encryption parameter generation. --- FHE/NoiseBounds.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index a1fe3e033..e2df9583f 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -105,7 +105,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants() { tp *= t; double lgtp = log(tp) / log(2.0); - if (C[i] < 0 && lgtp < FHE_epsilon) + if (C[i] < 0 && lgtp < -FHE_epsilon) { C[i] = pow(x, i); } From db6763513425d45b64396f57604c7729e59ba556 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 4 May 2022 14:11:13 +0200 Subject: [PATCH 060/221] Security bug in Temi matrix multiplication. --- Protocols/HemiMatrixPrep.hpp | 38 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index b2dd92d21..3446733e5 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -53,12 +53,14 @@ class MatrixRandMultJob : public ThreadJob public: MatrixRandMultJob(vector>& C, const vector>& A, - vector>& B) + vector>& B, + bool local_mul) { type = MATRX_RAND_MULT_JOB; output = &C; input = &A; supply = &B; + length = local_mul; } }; @@ -73,7 +75,8 @@ inline void matrix_rand_mult(ThreadJob job, true_type = {}) { A[i].randomize(G); B[i].randomize(G); - C[i] = A[i] * B[i]; + if (job.length) + C[i] = A[i] * B[i]; } } @@ -101,25 +104,22 @@ void HemiMatrixPrep::buffer_triples() B(n_matrices, {n_inner, n_cols}); SeededPRNG G; AddableVector> C(n_matrices); - MatrixRandMultJob job(C, A, B); + MatrixRandMultJob job(C, A, B, T::local_mul); - if (T::local_mul) + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) { - if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) - { - auto& queues = BaseMachine::s().queues; - int start = queues.distribute(job, n_matrices); - job.begin = start; - job.end = n_matrices; - matrix_rand_mult(job); - queues.wrap_up(job); - } - else - { - job.begin = 0; - job.end = n_matrices; - matrix_rand_mult(job); - } + auto& queues = BaseMachine::s().queues; + int start = queues.distribute(job, n_matrices); + job.begin = start; + job.end = n_matrices; + matrix_rand_mult(job); + queues.wrap_up(job); + } + else + { + job.begin = 0; + job.end = n_matrices; + matrix_rand_mult(job); } #ifdef VERBOSE_HE From 642d11f7dd99f6a5f5298090def695aeee906223 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 4 May 2022 14:09:15 +0200 Subject: [PATCH 061/221] Compile-time option for unencrypted client connections. --- ExternalIO/Client.h | 56 +++++++++++++++++++++++++++++++++-- ExternalIO/Client.hpp | 10 +++++-- Networking/sockets.h | 7 ----- Networking/ssl_sockets.h | 2 -- Processor/ExternalClients.cpp | 6 ++-- Processor/ExternalClients.h | 7 +++-- 6 files changed, 68 insertions(+), 20 deletions(-) diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index de9e9cad4..4e1e4c4b0 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -8,20 +8,72 @@ #include "Networking/ssl_sockets.h" +#ifdef NO_CLIENT_TLS +class client_ctx +{ +public: + client_ctx(string) + { + } +}; + +class client_socket +{ +public: + int socket; + + client_socket(boost::asio::io_service&, + client_ctx&, int plaintext_socket, string, + string, bool) : socket(plaintext_socket) + { + } + + ~client_socket() + { + close(socket); + } +}; + +inline void send(client_socket* socket, octet* data, size_t len) +{ + send(socket->socket, data, len); +} + +inline void receive(client_socket* socket, octet* data, size_t len) +{ + receive(socket->socket, data, len); +} + +#else + +typedef ssl_ctx client_ctx; + +class client_socket : public ssl_socket +{ +public: + client_socket(boost::asio::io_service& io_service, + boost::asio::ssl::context& ctx, int plaintext_socket, string other, + string me, bool client) : + ssl_socket(io_service, ctx, plaintext_socket, other, me, client) + { + } +}; +#endif + /** * Client-side interface */ class Client { vector plain_sockets; - ssl_ctx ctx; + client_ctx ctx; ssl_service io_service; public: /** * Sockets for cleartext communication */ - vector sockets; + vector sockets; /** * Specification of computation domain diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index 3af40f2f4..ffc9705cc 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -20,7 +20,7 @@ Client::Client(const vector& hostnames, int port_base, { set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i); octetStream(to_string(my_client_id)).Send(plain_sockets[i]); - sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], + sockets[i] = new client_socket(io_service, ctx, plain_sockets[i], "P" + to_string(i), "C" + to_string(my_client_id), true); if (i == 0) specification.Receive(sockets[0]); @@ -50,11 +50,15 @@ void Client::send_private_inputs(const vector& values) // Receive num_inputs triples from SPDZ for (size_t j = 0; j < sockets.size(); j++) { +#ifdef VERBOSE_COMM + cerr << "receiving from " << j << endl << flush; +#endif + os.reset_write_head(); os.Receive(sockets[j]); #ifdef VERBOSE_COMM - cerr << "received " << os.get_length() << " from " << j << endl; + cerr << "received " << os.get_length() << " from " << j << endl << flush; #endif for (int j = 0; j < num_inputs; j++) @@ -101,7 +105,7 @@ vector Client::receive_outputs(int n) os.reset_write_head(); os.Receive(socket); #ifdef VERBOSE_COMM - cout << "received " << os.get_length() << endl; + cout << "received " << os.get_length() << endl << flush; #endif for (int j = 0; j < 3 * n; j++) { diff --git a/Networking/sockets.h b/Networking/sockets.h index 7f48aad10..b67a20768 100644 --- a/Networking/sockets.h +++ b/Networking/sockets.h @@ -35,11 +35,6 @@ void send(T& socket, size_t a, size_t len); template void receive(T& socket, size_t& a, size_t len); -template -void send(T socket, octet* msg, size_t len); -template -void receive(T socket, octet* msg, size_t len); - inline size_t send_non_blocking(int socket, octet* msg, size_t len) { @@ -54,7 +49,6 @@ inline size_t send_non_blocking(int socket, octet* msg, size_t len) return j; } -template<> inline void send(int socket,octet *msg,size_t len) { size_t i = 0; @@ -72,7 +66,6 @@ inline void send(T& socket, size_t a, size_t len) send(socket, blen, len); } -template<> inline void receive(int socket,octet *msg,size_t len) { size_t i=0; diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index fe9477a81..816139953 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -87,7 +87,6 @@ inline size_t send_non_blocking(ssl_socket* socket, octet* data, size_t length) return socket->write_some(boost::asio::buffer(data, length)); } -template<> inline void send(ssl_socket* socket, octet* data, size_t length) { size_t sent = 0; @@ -103,7 +102,6 @@ inline void send(ssl_socket* socket, octet* data, size_t length) } } -template<> inline void receive(ssl_socket* socket, octet* data, size_t length) { size_t received = 0; diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index 65bb4598c..48bb8bd17 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -51,8 +51,8 @@ int ExternalClients::get_client_connection(int portnum_base) client); client_id = stoi(client); if (ctx == 0) - ctx = new ssl_ctx("P" + to_string(get_party_num())); - external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket, + ctx = new client_ctx("P" + to_string(get_party_num())); + external_client_sockets[client_id] = new client_socket(io_service, *ctx, socket, "C" + to_string(client_id), "P" + to_string(get_party_num()), false); client_ports[client_id] = portnum_base; cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl; @@ -75,7 +75,7 @@ int ExternalClients::get_party_num() return party_num; } -ssl_socket* ExternalClients::get_socket(int id) +client_socket* ExternalClients::get_socket(int id) { if (external_client_sockets.find(id) == external_client_sockets.end()) throw runtime_error("external connection not found for id " + to_string(id)); diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h index e437f3168..5ea1b3fdc 100644 --- a/Processor/ExternalClients.h +++ b/Processor/ExternalClients.h @@ -4,6 +4,7 @@ #include "Networking/sockets.h" #include "Networking/ssl_sockets.h" #include "Tools/Exceptions.h" +#include "ExternalIO/Client.h" #include #include #include @@ -25,11 +26,11 @@ class ExternalClients int party_num; // Maps holding per client values (indexed by unique 32-bit id) - std::map external_client_sockets; + std::map external_client_sockets; std::map client_ports; ssl_service io_service; - ssl_ctx* ctx; + client_ctx* ctx; public: @@ -43,7 +44,7 @@ class ExternalClients void close_connection(int client_id); // return the socket for a given client or server identifier - ssl_socket* get_socket(int socket_id); + client_socket* get_socket(int socket_id); int get_party_num(); }; From b3c39c4d37947fe7ce136ea0d0de5f8effcd564f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 11 May 2022 16:41:27 +0200 Subject: [PATCH 062/221] Missing vectorization. --- Compiler/types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Compiler/types.py b/Compiler/types.py index 99ca6a8c6..03fe7eb45 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1148,12 +1148,14 @@ def bit_decompose(self, bit_length=None): bit_length = bit_length or program.bit_length return floatingpoint.bits(self, bit_length) + @vectorize def legendre(self): """ Clear Legendre symbol computation. """ res = cint() legendrec(res, self) return res + @vectorize def digest(self, num_bytes): """ Clear hashing (libsodium default). """ res = cint() From 59fd44be22984218f774b251472cc583c30cffea Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 16 May 2022 15:25:33 +0100 Subject: [PATCH 063/221] Fix compilation with OpenSSL 3. --- ECDSA/P256Element.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 8437f39d2..2c8c776d2 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -29,7 +29,7 @@ P256Element::P256Element(const Scalar& other) : { BIGNUM* exp = BN_new(); BN_dec2bn(&exp, bigint(other).get_str().c_str()); - assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0); + assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0); BN_free(exp); } @@ -38,7 +38,7 @@ P256Element::P256Element(word other) : { BIGNUM* exp = BN_new(); BN_dec2bn(&exp, to_string(other).c_str()); - assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0); + assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0); BN_free(exp); } @@ -56,7 +56,11 @@ void P256Element::check() P256Element::Scalar P256Element::x() const { BIGNUM* x = BN_new(); +#if OPENSSL_VERSION_MAJOR >= 3 + assert(EC_POINT_get_affine_coordinates(curve, point, x, 0, 0) != 0); +#else assert(EC_POINT_get_affine_coordinates_GFp(curve, point, x, 0, 0) != 0); +#endif char* xx = BN_bn2dec(x); Scalar res((bigint(xx))); OPENSSL_free(xx); From 8e4fd45c17412dbbccabfb4dc796ed689b2f42e5 Mon Sep 17 00:00:00 2001 From: Jakob Zierk <48928791+jakobzierk@users.noreply.github.com> Date: Tue, 17 May 2022 09:17:02 +0200 Subject: [PATCH 064/221] Windows/VirtualBox performance Added workaround. --- doc/troubleshooting.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 6a79ea198..268084806 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -140,6 +140,29 @@ This indicates an error in the internal accounting of preprocessing. Please file a bug report. +Windows/VirtualBox performance +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Performance when using Windows/VirtualBox is by default abysmal, as +AVX/AVX2 instructions are deactivated (see e.g. +`here `_), +which causes a dramatic performance loss. Deactivate Hyper-V/Hypervisor +using:: + bcdedit /set hypervisorlaunchtype off + DISM /Online /Disable-Feature:Microsoft-Hyper-V + + +Performance can be further increased when compiling MP-SPDZ yourself: +:: + sudo apt-get update + sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm + git clone https://github.com/data61/MP-SPDZ.git + cd MP-SPDZ + make tldr + +See also `this issue `_ for a discussion. + + ``mac_fail`` ~~~~~~~~~~~~ From de12e08784636497dc315f287d08513640b6e50d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 23 May 2022 17:27:56 +0200 Subject: [PATCH 065/221] Fix bugs on macOS. --- Math/Z2k.h | 1 + Math/gf2n.h | 1 + Math/gfp.h | 1 + Tools/parse.h | 1 + 4 files changed, 4 insertions(+) diff --git a/Math/Z2k.h b/Math/Z2k.h index 586c78c06..cdde3f40c 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -85,6 +85,7 @@ class Z2 : public ValueInterface Z2(__m128i x) : Z2() { avx_memcpy(a, &x, min(N_BYTES, 16)); } Z2(int x) : Z2(long(x)) { a[N_WORDS - 1] &= UPPER_MASK; } Z2(long x) : Z2(mp_limb_t(x)) { if (K > 64 and x < 0) memset(&a[1], -1, N_BYTES - 8); } + Z2(long long x) : Z2(long(x)) {} template Z2(const IntBase& x); /** diff --git a/Math/gf2n.h b/Math/gf2n.h index 485d84308..3ec8849af 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -118,6 +118,7 @@ class gf2n_ : public ValueInterface gf2n_(U a) : a(a & mask) {} gf2n_(long a) : gf2n_(U(a)) {} gf2n_(int a) : gf2n_(U(unsigned(a))) {} + gf2n_(long long a) : gf2n_(U(a)) {} template gf2n_(IntBase a) : a(a.get()) {} diff --git a/Math/gfp.h b/Math/gfp.h index 3bc23e194..9a50dc035 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -161,6 +161,7 @@ class gfp_ : public ValueInterface gfp_(const mpz_class& x) { to_modp(a, x, ZpD); } gfp_(int x) : gfp_(long(x)) {} gfp_(long x); + gfp_(long long x) : gfp_(long(x)) {} gfp_(word x) : gfp_(bigint::tmp = x) {} template gfp_(IntBase x) : gfp_(x.get()) {} diff --git a/Tools/parse.h b/Tools/parse.h index 8ff0fee9e..c4b973dd7 100644 --- a/Tools/parse.h +++ b/Tools/parse.h @@ -13,6 +13,7 @@ using namespace std; #ifdef __APPLE__ # include #define be32toh(x) OSSwapBigToHostInt32(x) +#define be64toh(x) OSSwapBigToHostInt64(x) #endif // Read a 4-byte integer From 1460c9b5748c7cb7779cb79b2d3de792106be1f2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 24 May 2022 15:54:56 +0200 Subject: [PATCH 066/221] Fix output issue. --- Compiler/types.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/Compiler/types.py b/Compiler/types.py index 03fe7eb45..1d06f3f71 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5189,10 +5189,10 @@ def delete(self): self.value_type.free(self.address) self.address = None - def get_address(self, index): + def get_address(self, index, size=None): if isinstance(index, (_secret, _single)): raise CompilerError('need cleartext index') - key = str(index) + key = str(index), size or 1 if self.length is not None: from .GC.types import cbits if isinstance(index, int): @@ -5211,6 +5211,8 @@ def get_address(self, index): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() + if size is not None and isinstance(base, _register): + base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ for i in range(n)]) @@ -5332,7 +5334,8 @@ def assign(self, other, base=0): except: pass try: - self.value_type.conv(other).store_in_mem(self.get_address(base)) + other = self.value_type.conv(other) + other.store_in_mem(self.get_address(base, other.size)) if len(self) != None and util.is_constant(base): assert len(self) >= other.size + base except (AttributeError, CompilerError): @@ -5370,7 +5373,7 @@ def get_vector(self, base=0, size=None): :param base: starting point (regint/cint/int) :param size: length (compile-time int) """ size = size or self.length - base - return self.value_type.load_mem(self.get_address(base), size=size) + return self.value_type.load_mem(self.get_address(base, size), size=size) get_part_vector = get_vector @@ -5581,6 +5584,9 @@ def Array(self, size): # compatibility with registers return Array(size, self.value_type) + def output_if(self, cond): + library.print_str_if(cond, '%s', self.get_vector()) + def __str__(self): return '%s array of length %s at %s' % (self.value_type, len(self), self.address) From 2dad77ba326e5266d6447c69824896d1b458c08f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 24 May 2022 16:54:30 +0200 Subject: [PATCH 067/221] More flexible conversion. --- Compiler/GC/types.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 6c3abad0a..b34e68c82 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -154,9 +154,15 @@ def load_other(self, other): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): - assert(other.size == math.ceil(self.n / self.unit)) - for i, (x, y) in enumerate(zip(self, other)): + assert self.unit == 64 + n_units = int(math.ceil(self.n / self.unit)) + n_convs = min(other.size, n_units) + for i in range(n_convs): + x = self[i] + y = other[i] self.conv_regint(min(self.unit, self.n - i * self.unit), x, y) + for i in range(n_convs, n_units): + inst.ldbits(self[i], min(self.unit, self.n - i * self.unit), 0) elif (isinstance(self, type(other)) or isinstance(other, type(self))) \ and self.n == other.n: for i in range(math.ceil(self.n / self.unit)): From 5ab8c702dde2f25ae7f2f2d0e4d47f5d716fa621 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 27 May 2022 14:19:33 +0200 Subject: [PATCH 068/221] Secure shuffling. --- BMR/Party.cpp | 1 - BMR/RealProgramParty.hpp | 5 +- CHANGELOG.md | 9 + Compiler/GC/types.py | 5 +- Compiler/instructions.py | 65 +++++ Compiler/instructions_base.py | 5 + Compiler/oram.py | 142 ++++++----- Compiler/path_oram.py | 31 +-- Compiler/permutation.py | 203 ++-------------- Compiler/sorting.py | 54 +++++ Compiler/types.py | 139 ++++++++++- ExternalIO/Client.h | 11 +- FHE/AddableVector.h | 11 +- FHE/Ciphertext.cpp | 31 ++- FHE/Ciphertext.h | 25 +- FHE/Diagonalizer.cpp | 3 + FHE/FFT_Data.cpp | 5 + FHE/FFT_Data.h | 2 +- FHE/FHE_Keys.cpp | 64 ++++- FHE/FHE_Keys.h | 55 ++++- FHE/FHE_Params.cpp | 37 +++ FHE/FHE_Params.h | 24 ++ FHE/NTL-Subs.cpp | 31 ++- FHE/NoiseBounds.cpp | 6 +- FHE/NoiseBounds.h | 2 + FHE/P2Data.cpp | 8 +- FHE/Plaintext.cpp | 100 ++++---- FHE/Plaintext.h | 45 +++- FHE/Ring.cpp | 2 +- FHE/Ring_Element.cpp | 3 +- FHE/Rq_Element.cpp | 8 +- FHE/Rq_Element.h | 8 +- FHEOffline/Multiplier.cpp | 18 +- FHEOffline/PairwiseSetup.cpp | 26 +- GC/NoShare.h | 12 +- Machines/dealer-ring-party.cpp | 2 + Machines/mama-party.cpp | 2 +- Makefile | 1 + Math/FixedVec.h | 7 +- Math/Setup.cpp | 4 +- Math/Z2k.h | 6 + Math/Z2k.hpp | 5 +- Math/Zp_Data.cpp | 3 +- Math/gf2n.cpp | 15 +- Math/gf2n.h | 4 + Math/gfpvar.h | 6 + Networking/AllButLastPlayer.h | 17 +- Networking/CryptoPlayer.cpp | 4 +- Networking/Player.cpp | 8 - Networking/Player.h | 1 - Processor/Data_Files.hpp | 2 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 10 + Processor/Instruction.hpp | 72 ++++-- Processor/Machine.hpp | 7 +- Processor/OnlineMachine.hpp | 4 +- Processor/Processor.h | 21 ++ Processor/Processor.hpp | 56 +++++ Processor/RingMachine.hpp | 5 +- Programs/Source/dijkstra_example.mpc | 50 ++++ Programs/Source/dijkstra_tutorial.mpc | 9 - Protocols/Dealer.h | 36 +++ Protocols/DealerInput.h | 1 + Protocols/DealerInput.hpp | 13 +- Protocols/DealerMC.h | 1 + Protocols/DealerMC.hpp | 7 + Protocols/DealerMatrixPrep.h | 32 +++ Protocols/DealerMatrixPrep.hpp | 87 +++++++ Protocols/DealerPrep.h | 13 + Protocols/DealerPrep.hpp | 51 ++++ Protocols/DealerShare.h | 14 +- Protocols/FakeShare.h | 1 + Protocols/Hemi.h | 10 +- Protocols/Hemi.hpp | 41 +++- Protocols/HemiShare.h | 4 + Protocols/MAC_Check.h | 3 +- Protocols/MAC_Check.hpp | 11 +- Protocols/MAC_Check_Base.h | 4 + Protocols/MAC_Check_Base.hpp | 7 + Protocols/MaliciousRep3Share.h | 4 + Protocols/MaliciousRepMC.hpp | 2 +- Protocols/MaliciousShamirShare.h | 2 + Protocols/Rep3Share.h | 1 + Protocols/Rep4Share.h | 2 + Protocols/Replicated.h | 4 +- Protocols/Replicated.hpp | 12 +- Protocols/ReplicatedInput.h | 2 +- Protocols/ReplicatedMC.hpp | 2 +- Protocols/SecureShuffle.h | 53 +++++ Protocols/SecureShuffle.hpp | 328 ++++++++++++++++++++++++++ Protocols/SemiShare.h | 1 + Protocols/ShamirShare.h | 1 + Protocols/Share.h | 1 + Protocols/ShareInterface.h | 9 +- Protocols/ShareMatrix.h | 175 +++++++++++++- Protocols/TemiShare.h | 3 + Protocols/fake-stuff.hpp | 29 ++- README.md | 5 +- Tools/Exceptions.cpp | 6 +- Tools/Exceptions.h | 2 +- Tools/PointerVector.h | 9 + Tools/Waksman.cpp | 91 +++++++ Tools/Waksman.h | 39 +++ Utils/he-example.cpp | 97 ++++++++ doc/Doxyfile | 2 +- doc/homomorphic-encryption.rst | 31 +++ doc/index.rst | 1 + doc/troubleshooting.rst | 2 + 108 files changed, 2228 insertions(+), 543 deletions(-) create mode 100644 Compiler/sorting.py create mode 100644 Programs/Source/dijkstra_example.mpc delete mode 100644 Programs/Source/dijkstra_tutorial.mpc create mode 100644 Protocols/Dealer.h create mode 100644 Protocols/DealerMatrixPrep.h create mode 100644 Protocols/DealerMatrixPrep.hpp create mode 100644 Protocols/SecureShuffle.h create mode 100644 Protocols/SecureShuffle.hpp create mode 100644 Tools/Waksman.cpp create mode 100644 Tools/Waksman.h create mode 100644 Utils/he-example.cpp create mode 100644 doc/homomorphic-encryption.rst diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 84ba909b3..beddd64cf 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -259,7 +259,6 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl; delete P; } delete[] eval_threads; diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 8e16c3077..ae69cb7f5 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -28,7 +28,7 @@ RealProgramParty* RealProgramParty::singleton = 0; template RealProgramParty::RealProgramParty(int argc, const char** argv) : - garble_processor(garble_machine), dummy_proc({{}, 0}) + garble_processor(garble_machine), dummy_proc({}, 0) { assert(singleton == 0); singleton = this; @@ -157,6 +157,9 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : MC->Check(*P); data_sent = P->total_comm().sent; + if (online_opts.verbose) + P->total_comm().print(); + this->machine.write_memory(this->N.my_num()); } diff --git a/CHANGELOG.md b/CHANGELOG.md index 744d0ff1b..18cc92ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.2 (Mai 27, 2022) + +- Secure shuffling +- O(n log n) radix sorting +- Documented BGV encryption interface +- Optimized matrix multiplication in dealer protocol +- Fixed security bug in homomorphic encryption parameter generation +- Fixed Security bug in Temi matrix multiplication + ## 0.3.1 (Apr 19, 2022) - Protocol in dealer model diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index b34e68c82..fdd987225 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -382,7 +382,6 @@ class sbits(bits): reg_type = 'sb' is_clear = False clear_type = cbits - default_type = cbits load_inst = (inst.ldmsbi, inst.ldmsb) store_inst = (inst.stmsbi, inst.stmsb) bitdec = inst.bitdecs @@ -404,6 +403,9 @@ def new(value=None, n=None): else: return sbits.get_type(n)(value) @staticmethod + def _new(value): + return value + @staticmethod def get_random_bit(): res = sbit() inst.bitb(res) @@ -909,6 +911,7 @@ class cbit(bit, cbits): sbits.bit_type = sbit cbits.bit_type = cbit sbit.clear_type = cbit +sbits.default_type = sbits class bitsBlock(oram.Block): value_type = sbits diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 8a10ee58c..5f5b82dbc 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -17,6 +17,7 @@ import itertools import operator +import math from . import tools from random import randint from functools import reduce @@ -2406,6 +2407,70 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +@base.gf2n +class secshuffle(base.VectorInstruction, base.DataInstruction): + """ Secure shuffling. + + :param: destination (sint) + :param: source (sint) + """ + __slots__ = [] + code = base.opcodes['SECSHUFFLE'] + arg_format = ['sw','s','int'] + + def __init__(self, *args, **kwargs): + super(secshuffle_class, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', 0), float('inf')) + +class gensecshuffle(base.DataInstruction): + """ Generate secure shuffle to bit used several times. + + :param: destination (regint) + :param: size (int) + + """ + __slots__ = [] + code = base.opcodes['GENSECSHUFFLE'] + arg_format = ['ciw','int'] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', 0), float('inf')) + +class applyshuffle(base.VectorInstruction, base.DataInstruction): + """ Generate secure shuffle to bit used several times. + + :param: destination (sint) + :param: source (sint) + :param: number of elements to be treated as one (int) + :param: handle (regint) + :param: reverse (0/1) + + """ + __slots__ = [] + code = base.opcodes['APPLYSHUFFLE'] + arg_format = ['sw','s','int','ci','int'] + + def __init__(self, *args, **kwargs): + super(applyshuffle, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'triple', 0), float('inf')) + +class delshuffle(base.Instruction): + """ Delete secure shuffle. + + :param: handle (regint) + + """ + code = base.opcodes['DELSHUFFLE'] + arg_format = ['ci'] + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 8ae0b86fc..d598d8a71 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -106,6 +106,11 @@ CONV2DS = 0xAC, CHECK = 0xAF, PRIVATEOUTPUT = 0xAD, + # Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, # Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Compiler/oram.py b/Compiler/oram.py index 543fc4aab..d4b434385 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -348,7 +348,7 @@ def __iter__(self): def __len__(self): return 2 + len(self.x) def __repr__(self): - return '{empty=%s}' % self.is_empty if self.is_empty \ + return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \ else '{%s: %s}' % (self.v, self.x) def __add__(self, other): try: @@ -466,12 +466,14 @@ class AbstractORAM(object): def get_array(size, t, *args, **kwargs): return t.dynamic_array(size, t, *args, **kwargs) def read(self, index): - return self._read(self.value_type.hard_conv(index)) + res = self._read(self.index_type.hard_conv(index)) + res = [self.value_type._new(x) for x in res] + return res def write(self, index, value): + value = util.tuplify(value) + value = [self.value_type.conv(x) for x in value] new_value = [self.value_type.get_type(length).hard_conv(v) \ - for length,v in zip(self.entry_size, value \ - if isinstance(value, (tuple, list)) \ - else (value,))] + for length,v in zip(self.entry_size, value)] return self._write(self.index_type.hard_conv(index), *new_value) def access(self, index, new_value, write, new_empty=False): return self._access(self.index_type.hard_conv(index), @@ -795,7 +797,8 @@ def batch_init(self, values): for i,value in enumerate(values): index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.ram[i] = Entry(index, new_value, value_type=self.value_type) @@ -986,7 +989,8 @@ def batch_init(self, values): for i,value in enumerate(values): index = self.value_type.hard_conv(i) new_value = [self.value_type.hard_conv(v) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.__setitem__(index, new_value) def __repr__(self): @@ -1062,11 +1066,12 @@ def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ stop_timer(1) start_timer() self.root = RefBucket(1, self) - self.index = self.index_structure(size, self.D, value_type, init_rounds, True) + self.index = self.index_structure(size, self.D, self.index_type, + init_rounds, True) - self.read_value = Array(self.value_length, value_type) + self.read_value = Array(self.value_length, value_type.default_type) self.read_non_empty = MemValue(self.value_type.bit_type(0)) - self.state = MemValue(self.value_type(0)) + self.state = MemValue(self.value_type.default_type(0)) @method_block def add_to_root(self, state, is_empty, v, *x): if len(x) != self.value_length: @@ -1106,10 +1111,10 @@ def evict2(self, p_bucket1, p_bucket2, d): self.evict_bucket(RefBucket(p_bucket2, self), d) @method_block def read_and_renew_index(self, u): - l_star = random_block(self.D, self.value_type) + l_star = random_block(self.D, self.index_type) if use_insecure_randomness: new_path = regint.get_random(self.D) - l_star = self.value_type(new_path) + l_star = self.index_type(new_path) self.state.write(l_star) return self.index.update(u, l_star, evict=False).reveal() @method_block @@ -1120,7 +1125,7 @@ def read_and_remove_levels(self, u, read_path): parallel = get_parallel(self.index_size, *self.internal_value_type()) @map_sum(get_n_threads_for_tree(self.size), parallel, levels, \ self.value_length + 1, [self.value_type.bit_type] + \ - [self.value_type] * self.value_length) + [self.value_type.default_type] * self.value_length) def process(level): b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level)) bucket = RefBucket(b_index, self) @@ -1142,9 +1147,9 @@ def f(): Program.prog.curr_tape.start_new_basicblock() crash() def internal_value_type(self): - return self.value_type, self.value_length + 1 + return self.value_type.default_type, self.value_length + 1 def internal_entry_size(self): - return self.value_type, [self.D] + list(self.entry_size) + return self.value_type.default_type, [self.D] + list(self.entry_size) def n_buckets(self): return 2**(self.D+1) @method_block @@ -1176,8 +1181,9 @@ def add(self, entry, state=None, evict=True): #print 'pre-add', self maybe_start_timer(4) self.add_to_root(state, entry.empty(), \ - self.value_type(entry.v.read()), \ - *(self.value_type(i.read()) for i in entry.x)) + self.index_type(entry.v.read()), \ + *(self.value_type.default_type(i.read()) + for i in entry.x)) maybe_stop_timer(4) #print 'pre-evict', self if evict: @@ -1228,21 +1234,27 @@ def batch_init(self, values): raise CompilerError('Batch initialization only possible with sint.') depth = log2(m) - leaves = [0] * m - entries = [0] * m - indexed_values = [0] * m + leaves = self.value_type.Array(m) + indexed_values = \ + self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1) # assign indices 0, ..., m-1 - for i,value in enumerate(values): + @for_range(m) + def _(i): + value = values[i] index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ for v in (value if isinstance(value, (tuple, list)) \ else (value,))] indexed_values[i] = [index] + new_value - + entries = sint.Matrix(self.bucket_size * 2 ** self.D, + len(Entry(0, list(indexed_values[0]), False))) + # assign leaves - for i,index_value in enumerate(indexed_values): + @for_range(len(indexed_values)) + def _(i): + index_value = list(indexed_values[i]) leaves[i] = random_block(self.D, self.value_type) index = index_value[0] @@ -1252,18 +1264,20 @@ def batch_init(self, values): # save unsorted leaves for position map unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves] - permutation.sort(leaves, comp=permutation.normal_comparator) + leaves.sort() bucket_sz = 0 # B[i] = (pos, leaf, "last in bucket" flag) for i-th entry - B = [[0]*3 for i in range(m)] + B = sint.Matrix(m, 3) B[0] = [0, leaves[0], 0] B[-1] = [None, None, sint(1)] - s = 0 + s = MemValue(sint(0)) - for i in range(1, m): + @for_range_opt(m - 1) + def _(j): + i = j + 1 eq = leaves[i].equal(leaves[i-1]) - s = (s + eq) * eq + s.write((s + eq) * eq) B[i][0] = s B[i][1] = leaves[i] B[i-1][2] = 1 - eq @@ -1271,7 +1285,7 @@ def batch_init(self, values): #last_in_bucket[i-1] = 1 - eq # shuffle - permutation.shuffle(B, value_type=sint) + B.secure_shuffle() #cint(0).print_reg('shuf') sz = MemValue(0) #cint(0) @@ -1279,7 +1293,8 @@ def batch_init(self, values): empty_positions = Array(nleaves, self.value_type) empty_leaves = Array(nleaves, self.value_type) - for i in range(m): + @for_range(m) + def _(i): if_then(reveal(B[i][2])) #if B[i][2] == 1: #cint(i).print_reg('last') @@ -1291,12 +1306,13 @@ def batch_init(self, values): empty_positions[szval] = B[i][0] #pos[i][0] #empty_positions[szval].reveal().print_reg('ps0') empty_leaves[szval] = B[i][1] #pos[i][1] - sz += 1 + sz.iadd(1) end_if() - pos_bits = [] + pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2) - for i in range(nleaves): + @for_range_opt(nleaves) + def _(i): leaf = empty_leaves[i] # split into 2 if bucket size can't fit into one field elem if self.bucket_size + Program.prog.security > 128: @@ -1315,46 +1331,39 @@ def batch_init(self, values): bucket_bits = [b for sl in zip(bits2,bits) for b in sl] else: bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0] - pos_bits += [[b, leaf] for b in bucket_bits] + assert len(bucket_bits) == self.bucket_size + for j, b in enumerate(bucket_bits): + pos_bits[i * self.bucket_size + j] = [b, leaf] # sort to get empty positions first - permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator) + pos_bits.sort(n_bits=1) # now assign positions to empty entries - empty_entries = [0] * (self.bucket_size*2**self.D - m) - - for i in range(self.bucket_size*2**self.D - m): + @for_range(len(entries) - m) + def _(i): vtype, vlength = self.internal_value_type() leaf = vtype(pos_bits[i][1]) # set leaf in empty entry for assigning after shuffle - value = tuple([leaf] + [vtype(0) for j in range(vlength)]) + value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)]) entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype) - empty_entries[i] = entry + entries[m + i] = entry # now shuffle, reveal positions and place entries - entries = entries + empty_entries - while len(entries) & (len(entries)-1) != 0: - entries.append(None) - permutation.shuffle(entries, value_type=sint) - entries = [entry for entry in entries if entry is not None] - clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries] + entries.secure_shuffle() + clear_leaves = Array.create_from( + Entry(entries.get_columns()).x[0].reveal()) Program.prog.curr_tape.start_new_basicblock() bucket_sizes = Array(2**self.D, regint) for i in range(2**self.D): bucket_sizes[i] = 0 - k = 0 - for entry,leaf in zip(entries, clear_leaves): - leaf = leaf.read() - k += 1 - - # for some reason leaf_buckets is in bit-reversed order - bits = bit_decompose(leaf, self.D) - rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1])) - bucket = RefBucket(rev_leaf + (1 << self.D), self) - # hack: 1*entry ensures MemValues are converted to sints - bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry + + @for_range_opt(len(entries)) + def _(k): + leaf = clear_leaves[k] + bucket = RefBucket(leaf + (1 << self.D), self) + bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k]) bucket_sizes[leaf] += 1 self.index.batch_init([leaf.read() for leaf in unsorted_leaves]) @@ -1599,16 +1608,20 @@ def __setitem__(self, index, value): def batch_init(self, values): """ Initialize m values with indices 0, ..., m-1 """ m = len(values) - n_entries = max(1, m/self.entries_per_block) - new_values = [0] * n_entries + n_entries = max(1, m//self.entries_per_block) + new_values = sint.Matrix(n_entries, self.elements_per_block) + values = Array.create_from(values) - for i in range(n_entries): + @for_range(n_entries) + def _(i): block = [0] * self.elements_per_block for j in range(self.elements_per_block): base = i * self.entries_per_block + j * self.entries_per_element for k in range(self.entries_per_element): - if base + k < m: - block[j] += values[base + k] << (k * self.entry_size) + @if_(base + k < m) + def _(): + block[j] += \ + values[base + k] << (k * sum(self.entry_size)) new_values[i] = block @@ -1667,7 +1680,8 @@ def OptimalORAM(size,*args,**kwargs): experiments. :param size: number of elements - :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` + :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / + :py:class:`sfix` """ if optimal_threshold is None: if n_threads == 1: @@ -1784,7 +1798,7 @@ def test_batch_init(oram_type, N): oram = oram_type(N, value_type) print('initialized') print_reg(cint(0), 'init') - oram.batch_init([value_type(i) for i in range(N)]) + oram.batch_init(Array.create_from(sint(regint.inc(N)))) print_reg(cint(0), 'done') @for_range(N) def f(i): diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index fb1601c3d..b9e3952ba 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -111,24 +111,6 @@ def bucket_size_sorter(x, y): return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z]) -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - - - Returns the network switching config so it may be re-used later. """ - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = permutation.configure_waksman(permutation.random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type(b) for b in c] - permutation.waksman(x, config, reverse=reverse) - permutation.waksman(x, config, reverse=reverse) - - return config - - def LT(a, b): a_bits = bit_decompose(a) b_bits = bit_decompose(b) @@ -472,10 +454,15 @@ def f(): print_ln() # shuffle entries and levels - while len(merged_entries) & (len(merged_entries)-1) != 0: - merged_entries.append(None) #self.root.bucket.empty_entry(False)) - permutation.rec_shuffle(merged_entries, value_type=self.value_type) - merged_entries = [e for e in merged_entries if e is not None] + flat = [] + for x in merged_entries: + flat += list(x[0]) + [x[1]] + flat = self.value_type(flat) + assert len(flat) % len(merged_entries) == 0 + l = len(flat) // len(merged_entries) + shuffled = flat.secure_shuffle(l) + merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]] + for i in range(len(shuffled) // l)] # need to copy entries/levels to memory for re-positioning entries_ram = RAM(self.temp_size, self.entry_type, self.get_array) diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 6e1273ec4..07d3a3e70 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -10,16 +10,6 @@ from Compiler.program import Program _Array = Array -SORT_BITS = [] -insecure_random = Random(0) - -def predefined_comparator(x, y): - """ Assumes SORT_BITS is populated with the required sorting network bits """ - if predefined_comparator.sort_bits_iter is None: - predefined_comparator.sort_bits_iter = iter(SORT_BITS) - return next(predefined_comparator.sort_bits_iter) -predefined_comparator.sort_bits_iter = None - def list_comparator(x, y): """ Uses the first element in the list for comparison """ return x[0] < y[0] @@ -37,10 +27,6 @@ def bitwise_comparator(x, y): def cond_swap_bit(x,y, b): """ swap if b == 1 """ - if x is None: - return y, None - elif y is None: - return x, None if isinstance(x, list): t = [(xi - yi) * b for xi,yi in zip(x, y)] return [xi - ti for xi,ti in zip(x, t)], \ @@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator): else: raise CompilerError('Length of list must be power of two') -def merge(a, b, comp): - """ General length merge (pads to power of 2) """ - while len(a) & (len(a)-1) != 0: - a.append(None) - while len(b) & (len(b)-1) != 0: - b.append(None) - if len(a) < len(b): - a += [None] * (len(b) - len(a)) - elif len(b) < len(a): - b += [None] * (len(b) - len(b)) - t = a + b - odd_even_merge(t, comp) - for i,v in enumerate(t[::]): - if v is None: - t.remove(None) - return t - def sort(a, comp): """ Pads to power of 2, sorts, removes padding """ length = len(a) @@ -112,47 +81,12 @@ def sort(a, comp): odd_even_merge_sort(a, comp) del a[length:] -def recursive_merge(a, comp): - """ Recursively merge a list of sorted lists (initially sorted by size) """ - if len(a) == 1: - return - # merge smallest two lists, place result in correct position, recurse - t = merge(a[0], a[1], comp) - del a[0] - del a[0] - added = False - for i,c in enumerate(a): - if len(c) >= len(t): - a.insert(i, t) - added = True - break - if not added: - a.append(t) - recursive_merge(a, comp) - -def random_perm(n): - """ Generate a random permutation of length n - - WARNING: randomness fixed at compile-time, this is NOT secure - """ - if not Program.prog.options.insecure: - raise CompilerError('no secure implementation of Waksman permution, ' - 'use --insecure to activate') - a = list(range(n)) - for i in range(n-1, 0, -1): - j = insecure_random.randint(0, i) - t = a[i] - a[i] = a[j] - a[j] = t - return a - -def inverse(perm): - inv = [None] * len(perm) - for i, p in enumerate(perm): - inv[p] = i - return inv +# The following functionality for shuffling isn't used any more as it +# has been moved to the virtual machine. The code has been kept for +# reference. -def configure_waksman(perm): +def configure_waksman(perm, n_iter=[0]): + top = n_iter == [0] n = len(perm) if n == 2: return [(perm[0], perm[0])] @@ -175,6 +109,7 @@ def configure_waksman(perm): via = 0 j0 = j while True: + n_iter[0] += 1 #print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2) i = inv_perm[j] @@ -209,8 +144,11 @@ def configure_waksman(perm): assert sorted(p0) == list(range(n//2)) assert sorted(p1) == list(range(n//2)) - p0_config = configure_waksman(p0) - p1_config = configure_waksman(p1) + p0_config = configure_waksman(p0, n_iter) + p1_config = configure_waksman(p1, n_iter) + if top: + print(n_iter[0], 'iterations for Waksman') + assert O[0] == 0, 'not a Waksman network' return [I + O] + [a+b for a,b in zip(p0_config, p1_config)] def waksman(a, config, depth=0, start=0, reverse=False): @@ -358,23 +296,10 @@ def _(i): # nblocks /= 2 # depth -= 1 -def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False): - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = configure_waksman(random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type.bit_type(b) for b in c] - waksman(x, config, reverse=reverse) - waksman(x, config, reverse=reverse) - -def config_shuffle(n, value_type): - """ Compute config for oblivious shuffling. - - Take mod 2 for active sec. """ - perm = random_perm(n) +def config_from_perm(perm, value_type): + n = len(perm) + assert(list(sorted(perm))) == list(range(n)) if n & (n-1) != 0: # pad permutation to power of 2 m = 2**int(math.ceil(math.log(n, 2))) @@ -394,103 +319,3 @@ def _(i): for j,b in enumerate(c): config[i * len(perm) + j] = b return config - -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - WARNING: This is not a properly secure implementation but has roughly the right complexity. - - Returns the network switching config so it may be re-used later. """ - n = len(x) - m = 2**int(math.ceil(math.log(n, 2))) - assert n == m, 'only working for powers of two' - if config is None: - config = config_shuffle(n, value_type) - - if isinstance(x, list): - if isinstance(x[0], list): - length = len(x[0]) - assert len(x) == length - for i in range(length): - xi = Array(m, value_type.reg_type) - for j in range(n): - xi[j] = x[j][i] - for j in range(n, m): - xi[j] = value_type(0) - iter_waksman(xi, config, reverse=reverse) - iter_waksman(xi, config, reverse=reverse) - for j, y in enumerate(xi): - x[j][i] = y - else: - xa = Array(m, value_type.reg_type) - for i in range(n): - xa[i] = x[i] - for i in range(n, m): - xa[i] = value_type(0) - iter_waksman(xa, config, reverse=reverse) - iter_waksman(xa, config, reverse=reverse) - x[:] = xa - elif isinstance(x, Array): - if len(x) != m and config is None: - raise CompilerError('Non-power of 2 Array input not yet supported') - iter_waksman(x, config, reverse=reverse) - iter_waksman(x, config, reverse=reverse) - else: - raise CompilerError('Invalid type for shuffle:', type(x)) - - return config - -def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None): - """ Shuffle a list of ORAM entries. - - Randomly permutes the first "perm_size" entries, leaving the rest (empty - entry padding) in the same position. """ - n = len(x) - l = len(x[0]) - if n & (n-1) != 0: - raise CompilerError('Entries must be padded to power of two length.') - if perm_size is None: - perm_size = n - - xarrays = [Array(n, value_type.reg_type) for i in range(l)] - for i in range(n): - for j,value in enumerate(x[i]): - if isinstance(value, MemValue): - xarrays[j][i] = value.read() - else: - xarrays[j][i] = value - - if config is None: - config = config_shuffle(perm_size, value_type) - for xi in xarrays: - shuffle(xi, config, value_type, reverse) - for i in range(n): - x[i] = entry_cls(xarrays[j][i] for j in range(l)) - return config - - -def sort_zeroes(bits, x, n_ones, value_type): - """ Return Array of values in "x" where the corresponding bit in "bits" is - a 0. - - The total number of zeroes in "bits" must be known. - "bits" and "x" must be Arrays. """ - config = config_shuffle(len(x), value_type) - shuffle(bits, config=config, value_type=value_type) - shuffle(x, config=config, value_type=value_type) - result = Array(n_ones, value_type.reg_type) - - sz = MemValue(0) - last_x = MemValue(value_type(0)) - #for i,b in enumerate(bits): - #if_then(b.reveal() == 0) - #result[sz.read()] = x[i] - #sz += 1 - #end_if() - @for_range(len(bits)) - def f(i): - found = (bits[i].reveal() == 0) - szval = sz.read() - result[szval] = last_x + (x[i] - last_x) * found - sz.write(sz + found) - last_x.write(result[szval]) - return result diff --git a/Compiler/sorting.py b/Compiler/sorting.py new file mode 100644 index 000000000..248b3ea07 --- /dev/null +++ b/Compiler/sorting.py @@ -0,0 +1,54 @@ +import itertools +from Compiler import types, library, instructions + +def dest_comp(B): + Bt = B.transpose() + Bt_flat = Bt.get_vector() + St_flat = Bt.value_type.Array(len(Bt_flat)) + St_flat.assign(Bt_flat) + @library.for_range(len(St_flat) - 1) + def _(i): + St_flat[i + 1] = St_flat[i + 1] + St_flat[i] + Tt_flat = Bt.get_vector() * St_flat.get_vector() + Tt = types.Matrix(*Bt.sizes, B.value_type) + Tt.assign_vector(Tt_flat) + return sum(Tt) - 1 + +def reveal_sort(k, D, reverse=False): + assert len(k) == len(D) + library.break_point() + shuffle = types.sint.get_secure_shuffle(len(k)) + k_prime = k.get_vector().secure_permute(shuffle).reveal() + idx = types.Array.create_from(k_prime) + if reverse: + D.assign_vector(D.get_slice_vector(idx)) + library.break_point() + D.secure_permute(shuffle, reverse=True) + else: + D.secure_permute(shuffle) + library.break_point() + v = D.get_vector() + D.assign_slice_vector(idx, v) + library.break_point() + instructions.delshuffle(shuffle) + +def radix_sort(k, D, n_bits=None, signed=True): + assert len(k) == len(D) + bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) + if signed and len(bs) > 1: + bs[-1][:] = bs[-1][:].bit_not() + B = types.sint.Matrix(len(k), 2) + h = types.Array.create_from(types.sint(types.regint.inc(len(k)))) + @library.for_range(len(bs)) + def _(i): + b = bs[i] + B.set_column(0, 1 - b.get_vector()) + B.set_column(1, b.get_vector()) + c = types.Array.create_from(dest_comp(B)) + reveal_sort(c, h, reverse=False) + @library.if_e(i < len(bs) - 1) + def _(): + reveal_sort(h, bs[i + 1], reverse=True) + @library.else_ + def _(): + reveal_sort(h, D, reverse=True) diff --git a/Compiler/types.py b/Compiler/types.py index 1d06f3f71..098f493f0 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1937,6 +1937,11 @@ def matrix_mul(cls, A, B, n, res_params=None): matmuls(res, A, B, n_rows, n, n_cols) return res + @staticmethod + def _new(self): + # mirror sfix + return self + @no_doc def __init__(self, reg_type, val=None, size=None): if isinstance(val, self.clear_type): @@ -2093,6 +2098,12 @@ def square(self): else: return self * self + @set_instruction_type + def secure_shuffle(self, unit_size=1): + res = type(self)(size=self.size) + secshuffle(res, self, unit_size) + return res + @set_instruction_type @vectorize def reveal(self): @@ -2741,6 +2752,17 @@ def private_division(self, divisor, active=True, dividend_length=None, return w + @staticmethod + def get_secure_shuffle(n): + res = regint() + gensecshuffle(res, n) + return res + + def secure_permute(self, shuffle, unit_size=1, reverse=False): + res = sint(size=self.size) + applyshuffle(res, self, unit_size, shuffle, reverse) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -4291,6 +4313,10 @@ class revealed_fix(self.clear_type): k = self.k return revealed_fix._new(val) + def bit_decompose(self, n_bits=None): + """ Bit decomposition. """ + return self.v.bit_decompose(n_bits or self.k) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4312,6 +4338,8 @@ class sfix(_fix): int_type = sint bit_type = sintbit clear_type = cfix + get_type = staticmethod(lambda n: sint) + default_type = sint @vectorized_classmethod def get_input_from(cls, player): @@ -4385,6 +4413,10 @@ def expand_to_vector(self, size): def coerce(self, other): return parse_type(other, k=self.k, f=self.f) + def hard_conv_me(self, cls): + assert cls == sint + return self.v + def mul_no_reduce(self, other, res_params=None): assert self.f == other.f assert self.k == other.k @@ -4409,6 +4441,14 @@ def reveal_to(self, player): return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) + def secure_shuffle(self, *args, **kwargs): + return self._new(self.v.secure_shuffle(*args, **kwargs), + k=self.k, f=self.f) + + def secure_permute(self, *args, **kwargs): + return self._new(self.v.secure_permute(*args, **kwargs), + k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -5395,13 +5435,21 @@ def get(self, indices): regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) - def get_slice_vector(self, slice): + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 assert len(slice) <= self.total_size() base = regint.inc(len(slice), slice.address, 1, 1) - inc = regint.inc(len(slice), 0, 1, 1, 1) + inc = regint.inc(len(slice), self.address, 1, 1, 1) addresses = slice.value_type.load_mem(base) + inc - return self.value_type.load_mem(self.address + addresses) + return addresses + + def get_slice_vector(self, slice): + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(addresses) def expand_to_vector(self, index, size): """ Create vector from single entry. @@ -5514,6 +5562,14 @@ def shuffle(self): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + def secure_shuffle(self): + """ Secure shuffle in place according to the security model. """ + self.assign_vector(self.get_vector().secure_shuffle()) + + def secure_permute(self, *args, **kwargs): + """ Secure permutate in place according to the security model. """ + self.assign_vector(self.get_vector().secure_permute(*args, **kwargs)) + def randomize(self, *args): """ Randomize according to data type. """ self.assign_vector(self.value_type.get_random(*args, size=len(self))) @@ -5570,15 +5626,26 @@ def reveal_to(self, player): """ return personal(player, self.create_from(self[:].reveal_to(player)._v)) - def sort(self, n_threads=None): + def sort(self, n_threads=None, batcher=False, n_bits=None): """ - Sort in place using Batchers' odd-even merge mergesort - with complexity :math:`O(n (\log n)^2)`. + Sort in place using radix sort with complexity :math:`O(n \log + n)` for :py:class:`sint` and :py:class:`sfix`, and Batcher's + odd-even mergesort with :math:`O(n (\log n)^2)` for + :py:class:`sfloat`. :param n_threads: number of threads to use (single thread by - default) + default), need to use Batcher's algorithm for several threads + :param batcher: use Batcher's odd-even mergesort in any case + :param n_bits: number of bits in keys (default: global bit length) """ - library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + if batcher or self.value_type.n_elements() > 1: + library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + else: + if n_threads or 1 > 1: + raise CompilerError('multi-threaded sorting only implemented ' + 'with Batcher\'s odd-even mergesort') + import sorting + sorting.radix_sort(self, self, n_bits=n_bits) def Array(self, size): # compatibility with registers @@ -5619,6 +5686,8 @@ def __getitem__(self, index): :return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise""" if isinstance(index, slice) and index == slice(None): return self.get_vector() + if isinstance(index, int) and index < 0: + index += self.sizes[0] key = program.curr_block, str(index) if key not in self.sub_cache: if util.is_constant(index) and \ @@ -5673,6 +5742,10 @@ def f(i): def total_size(self): return reduce(operator.mul, self.sizes) * self.value_type.n_elements() + def part_size(self): + return reduce(operator.mul, self.sizes[1:]) * \ + self.value_type.n_elements() + def get_vector(self, base=0, size=None): """ Return vector with content. Not implemented for floating-point. @@ -5731,13 +5804,21 @@ def get_slice_vector(self, slice): :param slice: regint array """ + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(self.address + addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(self.address + addresses) + + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) assert len(slice) * part_size <= self.total_size() base = regint.inc(len(slice) * part_size, slice.address, 1, part_size) inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size) addresses = slice.value_type.load_mem(base) * part_size + inc - return self.value_type.load_mem(self.address + addresses) + return addresses def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 @@ -6218,6 +6299,31 @@ def diag(self): n = self.sizes[0] return self.array.get(regint.inc(n, 0, n + 1)) + def secure_shuffle(self): + """ Securely shuffle rows (first index). """ + self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) + + def secure_permute(self, permutation, reverse=False): + """ Securely permute rows (first index). """ + self.assign_vector(self.get_vector().secure_permute( + permutation, self.part_size(), reverse)) + + def sort(self, key_indices=None, n_bits=None): + """ Sort sub-arrays (different first index) in place. + + :param key_indices: indices to sorting keys, for example + ``(1, 2)`` to sort three-dimensional array ``a`` by keys + ``a[*][1][2]``. Default is ``(0, ..., 0)`` of correct length. + :param n_bits: number of bits in keys (default: global bit length) + + """ + if key_indices is None: + key_indices = (0,) * (len(self.sizes) - 1) + key_indices = (None,) + util.tuplify(key_indices) + import sorting + keys = self.get_vector_by_indices(*key_indices) + sorting.radix_sort(keys, self, n_bits=n_bits) + def randomize(self, *args): """ Randomize according to data type. """ if self.total_size() < program.options.budget: @@ -6334,6 +6440,18 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + @staticmethod + def create_from(rows): + rows = list(rows) + if isinstance(rows[0], (list, tuple)): + t = type(rows[0][0]) + else: + t = type(rows[0]) + res = Matrix(len(rows), len(rows[0]), t) + for i in range(len(rows)): + res[i].assign(rows[i]) + return res + def get_column(self, index): """ Get column as vector. @@ -6344,6 +6462,9 @@ def get_column(self, index): self.sizes[1]) return self.value_type.load_mem(addresses) + def get_columns(self): + return (self.get_column(i) for i in range(self.sizes[1])) + def get_column_by_row_indices(self, rows, column): assert self.value_type.n_elements() == 1 addresses = rows * self.sizes[1] + \ diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 4e1e4c4b0..fc5571b1d 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -47,17 +47,8 @@ inline void receive(client_socket* socket, octet* data, size_t len) #else typedef ssl_ctx client_ctx; +typedef ssl_socket client_socket; -class client_socket : public ssl_socket -{ -public: - client_socket(boost::asio::io_service& io_service, - boost::asio::ssl::context& ctx, int plaintext_socket, string other, - string me, bool client) : - ssl_socket(io_service, ctx, plaintext_socket, other, me, client) - { - } -}; #endif /** diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index 1efe1e228..b0a287444 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -58,7 +58,8 @@ class AddableVector: public vector { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - for (unsigned int i = 0; i < this->size(); i++) + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) (*this)[i] += y[i]; return *this; } @@ -67,9 +68,11 @@ class AddableVector: public vector { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - AddableVector res(y.size()); - for (unsigned int i = 0; i < this->size(); i++) - res[i] = (*this)[i] - y[i]; + AddableVector res; + res.reserve(y.size()); + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) + res.push_back((*this)[i] - y[i]); return res; } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 9afef83ce..00e051318 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -31,6 +31,12 @@ word check_pk_id(word a, word b) } +void Ciphertext::Scale() +{ + Scale(params->get_plaintext_modulus()); +} + + void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1) { if (c0.params!=c1.params) { throw params_mismatch(); } @@ -115,9 +121,28 @@ void Ciphertext::add(octetStream& os) *this += tmp; } +void Ciphertext::rerandomize(const FHE_PK& pk) +{ + Rq_Element tmp(*params); + SeededPRNG G; + vector r(params->FFTD()[0].m()); + bigint p = pk.p(); + assert(p != 0); + for (auto& x : r) + { + G.get(x, params->p0().numBits() - p.numBits() - 1); + x *= p; + } + tmp.from(r, 0); + Scale(); + cc0 += tmp; + auto zero = pk.encrypt(*params); + zero.Scale(pk.p()); + *this += zero; +} + template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); -template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); - - +template void mul(Ciphertext& ans, const Plaintext& a, + const Ciphertext& c); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index d455f1268..11a23e2ab 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -15,6 +15,12 @@ template void mul(Ciphertext& ans,const Ciphertext& c, void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1); void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk); +/** + * BGV ciphertext. + * The class allows adding two ciphertexts as well as adding a plaintext and + * a ciphertext via operator overloading. The multiplication of two ciphertexts + * requires the public key and thus needs a separate function. + */ class Ciphertext { Rq_Element cc0,cc1; @@ -54,6 +60,7 @@ class Ciphertext // Scale down an element from level 1 to level 0, if at level 0 do nothing void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); } + void Scale(); // Throws error if ans,c0,c1 etc have different params settings // - Thus programmer needs to ensure this rather than this being done @@ -90,6 +97,12 @@ class Ciphertext template Ciphertext& operator*=(const Plaintext_& other) { ::mul(*this, *this, other); return *this; } + /** + * Ciphertext multiplication. + * @param pk public key + * @param x second ciphertext + * @returns product ciphertext + */ Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const { Ciphertext res(*params); ::mul(res, *this, x, pk); return res; } @@ -98,14 +111,18 @@ class Ciphertext return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this}; } + /// Re-randomize for circuit privacy. + void rerandomize(const FHE_PK& pk); + int level() const { return cc0.level(); } - // pack/unpack (like IO) also assume params are known and already set - // correctly + /// Append to buffer void pack(octetStream& o) const { cc0.pack(o); cc1.pack(o); o.store(pk_id); } - void unpack(octetStream& o) - { cc0.unpack(o); cc1.unpack(o); o.get(pk_id); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& o) + { cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); } void output(ostream& s) const { cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); } diff --git a/FHE/Diagonalizer.cpp b/FHE/Diagonalizer.cpp index 9cc1a0840..958cd28cc 100644 --- a/FHE/Diagonalizer.cpp +++ b/FHE/Diagonalizer.cpp @@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag( { auto& c = products.at(i); for (int j = 0; j < n_matrices; j++) + { + res.at(j).entries.init(); for (size_t k = 0; k < n_rows; k++) res.at(j)[{k, i}] = c.element(j * n_rows + k); + } } return res; } diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index c71a4c5da..d3b67b506 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -7,6 +7,11 @@ +FFT_Data::FFT_Data() : + twop(-1) +{ +} + void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) { R=Rg; diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index c5d6b2063..4fb37ed48 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -50,7 +50,7 @@ class FFT_Data void pack(octetStream& o) const; void unpack(octetStream& o); - FFT_Data() { ; } + FFT_Data(); FFT_Data(const Ring& Rg,const Zp_Data& PrD) { init(Rg,PrD); } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 20dfb1bb5..742c85452 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -12,6 +12,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p()) { } +FHE_SK::FHE_SK(const FHE_Params& pms) : + FHE_SK(pms, pms.get_plaintext_modulus()) +{ +} + FHE_SK& FHE_SK::operator+=(const FHE_SK& c) { @@ -38,6 +43,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G) } +FHE_PK::FHE_PK(const FHE_Params& pms) : + FHE_PK(pms, pms.get_plaintext_modulus()) +{ +} + Rq_Element FHE_PK::sample_secret_key(PRNG& G) { Rq_Element sk = FHE_SK(*this).s(); @@ -179,32 +189,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext& template Ciphertext FHE_PK::encrypt( const Plaintext& mess) const +{ + return encrypt(Rq_Element(*params, mess)); +} + +Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const { Random_Coins rc(*params); PRNG G; G.ReSeed(); rc.generate(G); - return encrypt(mess, rc); + Ciphertext res(*params); + quasi_encrypt(res, mess, rc); + return res; } template void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const { - if (&c.get_params()!=params) { throw params_mismatch(); } if (T::characteristic_two ^ (pr == 2)) throw pr_mismatch(); + Rq_Element ans = quasi_decrypt(c); + mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); +} + +Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const +{ + if (&c.get_params()!=params) { throw params_mismatch(); } + Rq_Element ans; mul(ans,c.c1(),sk); sub(ans,c.c0(),ans); ans.change_rep(polynomial); - mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); + return ans; } +Plaintext_ FHE_SK::decrypt(const Ciphertext& c) +{ + return decrypt(c, params->get_plaintext_field_data()); +} + template Plaintext FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD) { @@ -299,12 +328,12 @@ void FHE_PK::unpack(octetStream& o) o.consume((octet*) tag, 8); if (memcmp(tag, "PKPKPKPK", 8)) throw runtime_error("invalid serialization of public key"); - a0.unpack(o); - b0.unpack(o); + a0.unpack(o, *params); + b0.unpack(o, *params); if (params->n_mults() > 0) { - Sw_a.unpack(o); - Sw_b.unpack(o); + Sw_a.unpack(o, *params); + Sw_b.unpack(o, *params); } pr.unpack(o); } @@ -322,7 +351,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const return false; } - void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const { @@ -345,8 +373,6 @@ void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) throw runtime_error("incorrect key pair"); } - - void FHE_PK::check(const FHE_Params& params, const bigint& pr) const { if (this->pr != pr) @@ -361,6 +387,24 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const } } +bigint FHE_SK::get_noise(const Ciphertext& c) +{ + sk.lower_level(); + Ciphertext cc = c; + if (cc.level()) + cc.Scale(); + Rq_Element tmp = quasi_decrypt(cc); + bigint res; + bigint q = tmp.get_modulus(); + bigint half_q = q / 2; + for (auto& x : tmp.to_vec_bigint()) + { +// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " "; + res = max(res, x > half_q ? x - q : x); + } + return res; +} + template void FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 30ecc2925..f342e203b 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -12,6 +12,10 @@ class FHE_PK; class Ciphertext; +/** + * BGV secret key. + * The class allows addition. + */ class FHE_SK { Rq_Element sk; @@ -29,6 +33,8 @@ class FHE_SK // secret key always on lower level void assign(const Rq_Element& s) { sk=s; sk.lower_level(); } + FHE_SK(const FHE_Params& pms); + FHE_SK(const FHE_Params& pms, const bigint& p) : sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; } @@ -38,8 +44,11 @@ class FHE_SK const Rq_Element& s() const { return sk; } + /// Append to buffer void pack(octetStream& os) const { sk.pack(os); pr.pack(os); } - void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); } // Assumes Ring and prime of mess have already been set correctly // Ciphertext c must be at level 0 or an error occurs @@ -50,9 +59,14 @@ class FHE_SK template Plaintext decrypt(const Ciphertext& c, const FD& FieldD); + /// Decryption for cleartexts modulo prime + Plaintext_ decrypt(const Ciphertext& c); + template void decrypt_any(Plaintext_& mess, const Ciphertext& c); + Rq_Element quasi_decrypt(const Ciphertext& c) const; + // Three stage procedure for Distributed Decryption // - First stage produces my shares // - Second stage adds in another players shares, do this once for each other player @@ -62,7 +76,6 @@ class FHE_SK void dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_number,int num_players) const; void dist_decrypt_2(vector& vv,const vector& vv1) const; - friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); /* Add secret keys @@ -82,10 +95,15 @@ class FHE_SK template void check(const FHE_PK& pk, const FD& FieldD); + bigint get_noise(const Ciphertext& c); + friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; } }; +/** + * BGV public key. + */ class FHE_PK { Rq_Element a0,b0; @@ -104,8 +122,10 @@ class FHE_PK ) { a0=a; b0=b; Sw_a=sa; Sw_b=sb; } - - FHE_PK(const FHE_Params& pms, const bigint& p = 0) + + FHE_PK(const FHE_Params& pms); + + FHE_PK(const FHE_Params& pms, const bigint& p) : a0(pms.FFTD(),evaluation,evaluation), b0(pms.FFTD(),evaluation,evaluation), Sw_a(pms.FFTD(),evaluation,evaluation), @@ -143,8 +163,11 @@ class FHE_PK template Ciphertext encrypt(const Plaintext& mess, const Random_Coins& rc) const; + + /// Encryption template Ciphertext encrypt(const Plaintext& mess) const; + Ciphertext encrypt(const Rq_Element& mess) const; friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); @@ -156,8 +179,10 @@ class FHE_PK void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; - // params setting is done out of these IO/pack/unpack functions + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); bool operator!=(const FHE_PK& x) const; @@ -170,21 +195,39 @@ class FHE_PK void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); +/** + * BGV key pair + */ class FHE_KeyPair { public: + /// Public key FHE_PK pk; + /// Secret key FHE_SK sk; - FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) : + FHE_KeyPair(const FHE_Params& params, const bigint& pr) : pk(params, pr), sk(params, pr) { } + /// Initialization + FHE_KeyPair(const FHE_Params& params) : + pk(params), sk(params) + { + } + void generate(PRNG& G) { KeyGen(pk, sk, G); } + + /// Generate fresh keys + void generate() + { + SeededPRNG G; + generate(G); + } }; template diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 5fb07f233..5a0f3991c 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -1,5 +1,6 @@ #include "FHE_Params.h" +#include "NTL-Subs.h" #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" #include "Protocols/HemiOptions.h" @@ -67,6 +68,7 @@ void FHE_Params::pack(octetStream& o) const Bval.pack(o); o.store(sec_p); o.store(matrix_dim); + fd.pack(o); } void FHE_Params::unpack(octetStream& o) @@ -80,6 +82,7 @@ void FHE_Params::unpack(octetStream& o) Bval.unpack(o); o.get(sec_p); o.get(matrix_dim); + fd.unpack(o); } bool FHE_Params::operator!=(const FHE_Params& other) const @@ -92,3 +95,37 @@ bool FHE_Params::operator!=(const FHE_Params& other) const else return false; } + +void FHE_Params::basic_generation_mod_prime(int plaintext_length) +{ + if (n_mults() == 0) + generate_semi_setup(plaintext_length, 0, *this, fd, false); + else + { + Parameters parameters(1, plaintext_length, 0); + parameters.generate_setup(*this, fd); + } +} + +template<> +const FFT_Data& FHE_Params::get_plaintext_field_data() const +{ + return fd; +} + +template<> +const P2Data& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +template<> +const PPData& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +bigint FHE_Params::get_plaintext_modulus() const +{ + return fd.get_prime(); +} diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 8821e2e29..4733245ca 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -15,6 +15,9 @@ #include "Tools/random.h" #include "Protocols/config.h" +/** + * Cryptosystem parameters + */ class FHE_Params { protected: @@ -29,8 +32,15 @@ class FHE_Params bigint Bval; int matrix_dim; + FFT_Data fd; + public: + /** + * Initialization. + * @param n_mults number of ciphertext multiplications (0/1) + * @param drown_sec parameter for function privacy (default 40) + */ FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY); int n_mults() const { return FFTData.size() - 1; } @@ -59,10 +69,24 @@ class FHE_Params int phi_m() const { return FFTData[0].phi_m(); } const Ring& get_ring() { return FFTData[0].get_R(); } + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer void unpack(octetStream& o); bool operator!=(const FHE_Params& other) const; + + /** + * Generate parameter for computation modulo a prime + * @param plaintext_length bit length of prime + */ + void basic_generation_mod_prime(int plaintext_length); + + template + const FD& get_plaintext_field_data() const; + + bigint get_plaintext_modulus() const; }; #endif diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 22705bedf..794e7431d 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -107,10 +107,12 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up) { +#ifdef VERBOSE cout << "Need ciphertext modulus of length " << lgp0; if (params.n_mults() > 0) cout << "+" << lgp1; cout << " and " << phi_N(m) << " slots" << endl; +#endif int extra_slack = 0; if (round_up) @@ -125,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, } extra_slack = i - 1; lgp0 += extra_slack; +#ifdef VERBOSE cout << "Rounding up to " << lgp0 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } Ring R; @@ -148,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool round_up, FHE_Params& params) { + (void) lg2pi, (void) n; + +#ifdef VERBOSE if (n >= 2 and n <= 10) cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] << ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl; cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl; cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl; +#endif int extra_slack = 0; if (round_up) @@ -171,11 +179,15 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, extra_slack = 2 * i; lg2p0 += i; lg2p1 += i; +#ifdef VERBOSE cout << "Rounding up to " << lg2p0 << "+" << lg2p1 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } +#ifdef VERBOSE cout << "Total length: " << lg2p0 + lg2p1 << endl; +#endif return extra_slack; } @@ -215,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p, { double phi_m_bound = NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1); + +#ifdef VERBOSE cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl; +#endif + if (phi_N(m) < phi_m_bound) { int old_m = m; + (void) old_m; m = 2 << int(ceil(log2(phi_m_bound))); + +#ifdef VERBOSE cout << "m = " << old_m << " too small, increasing it to " << m << endl; +#endif + generate_prime(p, numBits(p), m); } else @@ -244,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p, void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, const string& i, const bigint& pr0) { + (void) i; + if (lg2pr==0) { throw invalid_params(); } bigint step=m; @@ -260,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, assert(numBits(pr) == lg2pr); } +#ifdef VERBOSE cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl; + cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; +#endif assert(pr % m == 1); assert(pr % p == 1); assert(numBits(pr) == lg2pr); - - cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; } /* @@ -626,6 +650,9 @@ void char_2_dimension(int& m, int& lg2) case 16: m = 4369; break; + case 15: + m = 4681; + break; case 12: m = 4095; break; diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index e2df9583f..f4502317e 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -167,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1) bigint NoiseBounds::min_p1() { - return drown * B_KS + 1; + return max(bigint(drown * B_KS), bigint((phi_m * p) << 10)); } bigint NoiseBounds::opt_p1() @@ -181,8 +181,10 @@ bigint NoiseBounds::opt_p1() // solve mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a); bigint res = ceil(s); +#ifdef VERBOSE cout << "Optimal p1 vs minimal: " << numBits(res) << "/" << numBits(min_p1()) << endl; +#endif return res; } @@ -194,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1) { min_p0 *= 2; min_p1 *= 2; +#ifdef VERBOSE cout << "increasing lengths: " << numBits(min_p0) << "/" << numBits(min_p1) << endl; +#endif } lg2p1 = numBits(min_p1); lg2p0 = numBits(min_p0); diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index ccd50808a..565c663ef 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -42,6 +42,8 @@ class SemiHomomorphicNoiseBounds bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); } static double min_phi_m(int log_q, double sigma); static double min_phi_m(int log_q, const FHE_Params& params); + + bigint get_B_clean() { return B_clean; } }; // as per ePrint 2012:642 for slack = 0 diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index 7d9a8ca47..ac4ae6f16 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -55,13 +55,13 @@ void P2Data::check_dimensions() const // cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl; if (A.size() != Ai.size()) throw runtime_error("forward and backward mapping dimensions mismatch"); - if (A.size() != A[0].size()) + if (A.size() != A.at(0).size()) throw runtime_error("forward mapping not square"); - if (Ai.size() != Ai[0].size()) + if (Ai.size() != Ai.at(0).size()) throw runtime_error("backward mapping not square"); - if ((int)A[0].size() != slots * gf2n_short::degree()) + if ((int)A.at(0).size() != slots * gf2n_short::degree()) throw runtime_error( - "mapping dimension incorrect: " + to_string(A[0].size()) + "mapping dimension incorrect: " + to_string(A.at(0).size()) + " != " + to_string(slots) + " * " + to_string(gf2n_short::degree())); } diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 84cbb9d19..4eba6e8f0 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -11,10 +11,43 @@ +template +Plaintext::Plaintext(const FHE_Params& params) : + Plaintext(params.get_plaintext_field_data(), Both) +{ +} + + +template +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).phi_m(); +} + +template +int Plaintext::degree() const +{ + return (*Field_Data).phi_m(); +} + + +template<> +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).num_slots(); +} + +template<> +int Plaintext::degree() const +{ + return (*Field_Data).degree(); +} + + template<> void Plaintext::from(const Generator& source) const { - b.resize(degree); + b.resize(degree()); for (auto& x : b) { source.get(bigint::tmp); @@ -31,7 +64,7 @@ void Plaintext::from_poly() const Ring_Element e(*Field_Data,polynomial); e.from(b); e.change_rep(evaluation); - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::from_poly() const for (unsigned int i=0; iget_prD()}; type=Both; @@ -90,7 +123,7 @@ template<> void Plaintext::from_poly() const { if (type!=Polynomial) { return; } - a.resize(n_slots); + a.resize(num_slots()); (*Field_Data).backward(a,b); type=Both; } @@ -106,34 +139,13 @@ void Plaintext::to_poly() const -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).num_slots(); - degree = (*Field_Data).degree(); -} - - template void Plaintext::allocate(PT_Type type) const { if (type != Evaluation) - b.resize(degree); + b.resize(degree()); if (type != Polynomial) - a.resize(n_slots); + a.resize(num_slots()); this->type = type; } @@ -141,7 +153,7 @@ void Plaintext::allocate(PT_Type type) const template void Plaintext::allocate_slots(const bigint& value) { - b.resize(degree); + b.resize(degree()); for (auto& x : b) x.allocate_slots(value); } @@ -236,7 +248,7 @@ void Plaintext::randomize(PRNG& G,condition cond) type=Polynomial; break; case Diagonal: - a.resize(n_slots); + a.resize(num_slots()); a[0].randomize(G); for (unsigned int i=1; i::randomize(PRNG& G,condition cond) break; default: // Gen a plaintext with 0/1 in each slot - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t) b[0].generateUniform(G, n_bits, false); } else - for (int i = 0; i < n_slots; i++) + for (size_t i = 0; i < num_slots(); i++) b[i].generateUniform(G, n_bits, false); break; default: @@ -288,7 +300,7 @@ void Plaintext::assign_zero(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::assign_one(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::equals(const Plaintext& x) const if (type!=Polynomial and x.type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::unpack(octetStream& o) unsigned int size; o.get(size); allocate(); - if (size != b.size()) + if (size != b.size() and size != 0) throw length_error("unexpected length received"); - for (unsigned int i = 0; i < b.size(); i++) + for (unsigned int i = 0; i < size; i++) b[i] = o.get(); } diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index 52ff8b6d4..c8fb93c73 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -18,6 +18,7 @@ */ #include "FHE/Generator.h" +#include "FHE/FFT_Data.h" #include "Math/fixint.h" #include @@ -25,6 +26,8 @@ using namespace std; class FHE_PK; class Rq_Element; +class FHE_Params; +class FFT_Data; template class AddableVector; // Forward declaration as apparently this is needed for friends in templates @@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits }; enum PT_Type { Polynomial, Evaluation, Both }; +/** + * BGV plaintext. + * Use ``Plaintext_mod_prime`` instead of filling in the templates. + * The plaintext is held in one of the two representations or both, + * polynomial and evaluation. The latter is the one allowing element-wise + * multiplication over a vector. + * Plaintexts can be added, subtracted, and multiplied via operator overloading. + */ template class Plaintext { typedef typename FD::poly_type S; - int n_slots; - int degree; mutable vector a; // The thing in evaluation/FFT form mutable vector b; // Now in polynomial form @@ -58,33 +67,47 @@ class Plaintext const FD *Field_Data; - void set_sizes(); + int degree() const; public: const FD& get_field() const { return *Field_Data; } - unsigned int num_slots() const { return n_slots; } + + /// Number of slots + unsigned int num_slots() const; Plaintext(const FD& FieldD, PT_Type type = Polynomial) - { Field_Data=&FieldD; set_sizes(); allocate(type); } + { Field_Data=&FieldD; allocate(type); } Plaintext(const FD& FieldD, const Rq_Element& other); + /// Initialization + Plaintext(const FHE_Params& params); + void allocate(PT_Type type) const; void allocate() const { allocate(type); } void allocate_slots(const bigint& value); int get_min_alloc(); - // Access evaluation representation + /** + * Read slot. + * @param i slot number + * @returns slot content + */ T element(int i) const { if (type==Polynomial) { from_poly(); } return a[i]; } + /** + * Write to slot + * @param i slot number + * @param e new slot content + */ void set_element(int i,const T& e) { if (type==Polynomial) { throw not_implemented(); } - a.resize(n_slots); + a.resize(num_slots()); a[i]=e; type=Evaluation; } @@ -171,10 +194,10 @@ class Plaintext bool is_diagonal() const; - /* Pack and unpack into an octetStream - * For unpack we assume the FFTD has been assigned correctly already - */ + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); size_t report_size(ReportType type); @@ -185,4 +208,6 @@ class Plaintext template using Plaintext_ = Plaintext; +typedef Plaintext_ Plaintext_mod_prime; + #endif diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp index c1c318b8d..3b63f3069 100644 --- a/FHE/Ring.cpp +++ b/FHE/Ring.cpp @@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o) o.get(pi_inv); o.get(poly); } - else + else if (mm != 0) init(*this, mm); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 554d4dc10..39690fa6a 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -87,7 +87,6 @@ void Ring_Element::negate() void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { - if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty()) { @@ -100,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) return; } + if (a.rep!=b.rep) { throw rep_mismatch(); } + if (&ans == &a) { ans += b; diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index 531df90f7..d6a14aabd 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -5,7 +5,7 @@ #include "Math/modp.hpp" Rq_Element::Rq_Element(const FHE_PK& pk) : - Rq_Element(pk.get_params().FFTD()) + Rq_Element(pk.get_params().FFTD(), evaluation, evaluation) { } @@ -347,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const return sz; } +void Rq_Element::unpack(octetStream& o, const FHE_Params& params) +{ + set_data(params.FFTD()); + unpack(o); +} + void Rq_Element::print_first_non_zero() const { vector v = to_vec_bigint(); diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index a58cb7de0..4e0cdf97b 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -69,8 +69,9 @@ class Rq_Element a({b0}), lev(n_mults()) {} template - Rq_Element(const FHE_Params& params, const Plaintext& plaintext) : - Rq_Element(params) + Rq_Element(const FHE_Params& params, const Plaintext& plaintext, + RepType r0 = polynomial, RepType r1 = polynomial) : + Rq_Element(params, r0, r1) { from(plaintext.get_iterator()); } @@ -159,6 +160,9 @@ class Rq_Element void pack(octetStream& o) const; void unpack(octetStream& o); + // without prior initialization + void unpack(octetStream& o, const FHE_Params& params); + void output(ostream& s) const; void input(istream& s); diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 43ad7e842..3df98c851 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -57,7 +57,7 @@ void Multiplier::multiply_and_add(Plaintext_& res, template void Multiplier::add(Plaintext_& res, const Ciphertext& c, - OT_ROLE role, int n_summands) + OT_ROLE role, int) { o.reset_write_head(); @@ -67,20 +67,10 @@ void Multiplier::add(Plaintext_& res, const Ciphertext& c, G.ReSeed(); timers["Mask randomization"].start(); product_share.randomize(G); - bigint B = 6 * machine.setup().params.get_R(); - B *= machine.setup().FieldD.get_prime(); - B <<= machine.setup().params.secp(); - // slack - B *= NonInteractiveProof::slack(machine.sec, - machine.setup().params.phi_m()); - B <<= machine.extra_slack; - B *= n_summands; - rc.generateUniform(G, 0, B, B); + mask = c; + mask.rerandomize(other_pk); timers["Mask randomization"].stop(); - timers["Encryption"].start(); - other_pk.encrypt(mask, product_share, rc); - timers["Encryption"].stop(); - mask += c; + mask += product_share; mask.pack(o); res -= product_share; } diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 59223ad03..019711829 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -75,6 +75,8 @@ void secure_init(T& setup, Player& P, U& machine, + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); + string reason; + try { ifstream file(filename); @@ -82,12 +84,30 @@ void secure_init(T& setup, Player& P, U& machine, os.input(file); os.get(machine.extra_slack); setup.unpack(os); + } + catch (exception& e) + { + reason = e.what(); + } + + try + { setup.check(P, machine); } - catch (...) + catch (exception& e) + { + reason = e.what(); + } + + if (not reason.empty()) { - cout << "Finding parameters for security " << sec << " and field size ~2^" - << plaintext_length << endl; + if (OnlineOptions::singleton.verbose) + cerr << "Generating parameters for security " << sec + << " and field size ~2^" << plaintext_length + << " because no suitable material " + "from a previous run was found (" << reason << ")" + << endl; + setup = {}; setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; diff --git a/GC/NoShare.h b/GC/NoShare.h index 49f93ac42..917e71c5e 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -50,11 +50,6 @@ class NoValue : public ValueInterface return "no"; } - static string type_short() - { - return "no"; - } - static DataFieldType field_type() { throw not_implemented(); @@ -66,7 +61,7 @@ class NoValue : public ValueInterface static void fail() { - throw runtime_error("VM does not support binary circuits"); + throw runtime_error("functionality not available"); } NoValue() {} @@ -143,6 +138,11 @@ class NoShare : public ShareInterface return 0; } + static int length() + { + return 0; + } + static void fail() { NoValue::fail(); diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp index 4bc8fab1a..890a24ab5 100644 --- a/Machines/dealer-ring-party.cpp +++ b/Machines/dealer-ring-party.cpp @@ -5,6 +5,7 @@ #include "Protocols/DealerShare.h" #include "Protocols/DealerInput.h" +#include "Protocols/Dealer.h" #include "Processor/RingMachine.hpp" #include "Processor/Machine.hpp" @@ -12,6 +13,7 @@ #include "Protocols/DealerPrep.hpp" #include "Protocols/DealerInput.hpp" #include "Protocols/DealerMC.hpp" +#include "Protocols/DealerMatrixPrep.hpp" #include "Protocols/Beaver.hpp" #include "Semi.hpp" #include "GC/DealerPrep.h" diff --git a/Machines/mama-party.cpp b/Machines/mama-party.cpp index f270b87ce..87bf15eaa 100644 --- a/Machines/mama-party.cpp +++ b/Machines/mama-party.cpp @@ -21,5 +21,5 @@ using MamaShare_ = MamaShare; int main(int argc, const char** argv) { ez::ezOptionParser opt; - DishonestMajorityFieldMachine(argc, argv, opt); + DishonestMajorityFieldMachine(argc, argv, opt); } diff --git a/Makefile b/Makefile index 3c2be0090..03366f89d 100644 --- a/Makefile +++ b/Makefile @@ -244,6 +244,7 @@ paper-example.x: $(VM) $(OT) $(FHEOFFLINE) binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o l2h-example.x: $(VM) $(OT) Machines/Tinier.o +he-example.x: $(FHEOFFLINE) mascot-offline.x: $(VM) $(TINIER) cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index c0b2373ed..489ec5ae9 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -24,7 +24,12 @@ class FixedVec typedef T value_type; typedef FixedVec Scalar; - static const int length = L; + static const int vector_length = L; + + static int length() + { + return L * T::length(); + } static int size() { diff --git a/Math/Setup.cpp b/Math/Setup.cpp index dc76e47d7..715d480d6 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -136,7 +136,7 @@ void write_online_setup(string dirname, const bigint& p) if (mkdir_p(ss.str().c_str()) == -1) { cerr << "mkdir_p(" << ss.str() << ") failed\n"; - throw file_error(ss.str()); + throw file_error("cannot create " + dirname); } // Output the data @@ -167,6 +167,6 @@ string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, res += "-" + to_string(log2mod); res += "/"; if (mkdir_p(res.c_str()) < 0) - throw file_error(res); + throw file_error("cannot create " + res); return res; } diff --git a/Math/Z2k.h b/Math/Z2k.h index cdde3f40c..e8d2ba532 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -439,6 +439,12 @@ void Z2::randomize(PRNG& G, int n) template void Z2::randomize_part(PRNG& G, int n) { + if (n >= N_BITS) + { + randomize(G); + return; + } + *this = {}; G.get_octets((octet*)a, DIV_CEIL(n, 8)); a[DIV_CEIL(n, 64) - 1] &= mp_limb_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS); diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 876aef939..ef2f84c98 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -67,7 +67,10 @@ Z2::Z2(const IntBase& x) : template bool Z2::get_bit(int i) const { - return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); + if (i < N_BITS) + return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); + else + return false; } template diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 17fcdf24c..9dd0b7f04 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -174,7 +174,8 @@ void Zp_Data::unpack(octetStream& o) int m; o.get(m); montgomery = m; - init(pr, m); + if (pr != 0) + init(pr, m); } bool Zp_Data::operator!=(const Zp_Data& other) const diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 44e424794..d39a8593e 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -44,6 +44,19 @@ int fields_2[num_2_fields][4] = { 128, 7, 2, 1 }, }; +template +string gf2n_::options() +{ + string res = to_string(fields_2[0][0]); + for (int i = 1; i < num_2_fields; i++) + { + int n = fields_2[i][0]; + if (n <= MAX_N_BITS) + res += ", " + to_string(n); + } + return res; +} + template void gf2n_::init_tables() { @@ -113,7 +126,7 @@ void gf2n_::init_field(int nn) if (j==-1) { - throw gf2n_not_supported(nn); + throw gf2n_not_supported(nn, options()); } n=nn; diff --git a/Math/gf2n.h b/Math/gf2n.h index 3ec8849af..56377072a 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -86,6 +86,8 @@ class gf2n_ : public ValueInterface static bool allows(Dtype type) { (void) type; return true; } + static string options(); + static const true_type invertible; static const true_type characteristic_two; @@ -154,6 +156,8 @@ class gf2n_ : public ValueInterface gf2n_ operator*(int x) const { return *this * gf2n_(x); } gf2n_ invert() const; + + gf2n_ operator-() const { return *this; } void negate() { return; } /* Bitwise Ops */ diff --git a/Math/gfpvar.h b/Math/gfpvar.h index a3b475f8c..7d332fdd8 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -107,6 +107,12 @@ class gfpvar_ a = other.get(); } + template + gfpvar_(const Z2& other) : + gfpvar_(bigint(other)) + { + } + void assign(const void* buffer); void assign_zero(); diff --git a/Networking/AllButLastPlayer.h b/Networking/AllButLastPlayer.h index 22482c481..3d6d18344 100644 --- a/Networking/AllButLastPlayer.h +++ b/Networking/AllButLastPlayer.h @@ -50,17 +50,12 @@ class AllButLastPlayer : public Player void Broadcast_Receive_no_stats(vector& os) const { - vector to_send(P.num_players(), os[P.my_num()]); - vector> channels(P.num_players(), - vector(P.num_players(), true)); - for (auto& x: channels) - x.back() = false; - channels.back() = vector(P.num_players(), false); - vector to_receive; - P.send_receive_all(channels, to_send, to_receive); - for (int i = 0; i < P.num_players() - 1; i++) - if (i != P.my_num()) - os[i] = to_receive[i]; + vector senders(P.num_players(), true), receivers(P.num_players(), + true); + senders.back() = false; + receivers.back() = false; + P.partial_broadcast(senders, receivers, os); + os.resize(num_players()); } }; diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 43b2ada5c..faf8fda63 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -212,8 +212,8 @@ void CryptoPlayer::partial_broadcast(const vector& my_senders, for (int offset = 1; offset < num_players(); offset++) { int other = get_player(offset); - bool receive = my_senders[other]; - if (my_receivers[other]) + bool receive = my_senders.at(other); + if (my_receivers.at(other)) { this->senders[other]->request(os[my_num()]); sent += os[my_num()].get_length(); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index a7935f305..3a8942148 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -811,14 +811,6 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const return res; } -size_t NamedCommStats::total_data() -{ - size_t res = 0; - for (auto& x : *this) - res += x.second.data; - return res; -} - void NamedCommStats::print(bool newline) { for (auto it = begin(); it != end(); it++) diff --git a/Networking/Player.h b/Networking/Player.h index a547d4795..cf8579c0e 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -157,7 +157,6 @@ class NamedCommStats : public map NamedCommStats& operator+=(const NamedCommStats& other); NamedCommStats operator+(const NamedCommStats& other) const; NamedCommStats operator-(const NamedCommStats& other) const; - size_t total_data(); void print(bool newline = false); void reset(); #ifdef VERBOSE_COMM diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index dd7466055..3d40e2ca7 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -230,7 +230,7 @@ void Sub_Data_Files::prune() my_input_buffers.prune(); for (int j = 0; j < num_players; j++) input_buffers[j].prune(); - for (auto it : extended) + for (auto& it : extended) it.second.prune(); dabit_buffer.prune(); if (part != 0) diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 09c6e056a..2eb8a63a6 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -293,7 +293,7 @@ int InputBase::get_player(SubProcessor& Proc, int arg, bool player_from_re if (player_from_reg) { assert(Proc.Proc); - auto res = Proc.Proc->read_Ci(arg); + auto res = Proc.Proc->sync_Ci(arg); if (res >= Proc.P.num_players()) throw runtime_error("player id too large: " + to_string(res)); return res; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 5279b2584..fd91e35d3 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -13,6 +13,7 @@ using namespace std; template class Machine; template class Processor; +template class SubProcessor; class ArithmeticProcessor; class SwitchableOutput; @@ -107,6 +108,11 @@ enum CONV2DS = 0xAC, CHECK = 0xAF, PRIVATEOUTPUT = 0xAD, + // Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -250,6 +256,7 @@ enum GMULS = 0x1A6, GMULRS = 0x1A7, GDOTPRODS = 0x1A8, + GSECSHUFFLE = 0x1FA, // Data access GTRIPLE = 0x150, GBIT = 0x151, @@ -388,6 +395,9 @@ class Instruction : public BaseInstruction template void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0, T* nan = 0) const; + + template + typename T::clear sanitize(SubProcessor& proc, int reg) const; }; #endif diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 2a5dce70c..5bed37037 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -157,6 +157,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case LISTEN: case CLOSECLIENTCONNECTION: case CRASH: + case DELSHUFFLE: r[0]=get_int(s); break; // instructions with 2 registers + 1 integer operand @@ -203,6 +204,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case DIGESTC: case INPUTMASK: case GINPUTMASK: + case SECSHUFFLE: + case GSECSHUFFLE: get_ints(r, s, 2); n = get_int(s); break; @@ -230,6 +233,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CONDPRINTSTR: case CONDPRINTSTRB: case RANDOMS: + case GENSECSHUFFLE: r[0]=get_int(s); n = get_int(s); break; @@ -269,6 +273,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // instructions with 5 register operands case PRINTFLOATPLAIN: case PRINTFLOATPLAINB: + case APPLYSHUFFLE: get_vector(5, start, s); break; case INCINT: @@ -558,6 +563,7 @@ int BaseInstruction::get_reg_type() const case CONVCBITVEC: case INTOUTPUT: case ACCEPTCLIENTCONNECTION: + case GENSECSHUFFLE: return INT; case PREP: case GPREP: @@ -835,11 +841,13 @@ inline void Instruction::execute(Processor& Proc) const { for (int i = 0; i < size; i++) Proc.write_Ci(r[0] + i, - Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get()); + Proc.sync( + Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get())); } else if (n <= 64) for (int i = 0; i < size; i++) - Proc.write_Ci(r[0] + i, Integer(Proc.read_Cp(r[1] + i), n).get()); + Proc.write_Ci(r[0] + i, + Proc.sync(Integer(Proc.read_Cp(r[1] + i), n).get())); else throw Processor_Error(to_string(n) + "-bit conversion impossible; " "integer registers only have 64 bits"); @@ -856,40 +864,32 @@ inline void Instruction::execute(Processor& Proc) const n++; break; case LDMCI: - Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1]))); + Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.sync_Ci(r[1]))); break; case STMC: Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0])); n++; break; case STMCI: - Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0])); + Proc.machine.Mp.write_C(Proc.sync_Ci(r[1]), Proc.read_Cp(r[0])); break; case MOVC: Proc.write_Cp(r[0],Proc.read_Cp(r[1])); break; case DIVC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); - Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / Proc.read_Cp(r[2])); + Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2])); break; case GDIVC: - if (Proc.read_C2(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); - Proc.write_C2(r[0], Proc.read_C2(r[1]) / Proc.read_C2(r[2])); + Proc.write_C2(r[0], Proc.read_C2(r[1]) / sanitize(Proc.Proc2, r[2])); break; case FLOORDIVC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); Proc.temp.aa.from_signed(Proc.read_Cp(r[1])); - Proc.temp.aa2.from_signed(Proc.read_Cp(r[2])); + Proc.temp.aa2.from_signed(sanitize(Proc.Procp, r[2])); Proc.write_Cp(r[0], bigint(Proc.temp.aa / Proc.temp.aa2)); break; case MODC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Modulo by zero from register"); to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); - to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2])); + to_bigint(Proc.temp.aa2, sanitize(Proc.Procp, r[2])); mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t()); Proc.temp.ansp.convert_destroy(Proc.temp.aa); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -948,7 +948,7 @@ inline void Instruction::execute(Processor& Proc) const Procp.protocol.randoms_inst(Procp.get_S(), *this); return; case INPUTMASKREG: - Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2])); + Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.sync_Ci(r[2])); Proc.write_Cp(r[1], Proc.temp.rrp); break; case INPUTMASK: @@ -1034,7 +1034,7 @@ inline void Instruction::execute(Processor& Proc) const return; case MATMULSM: Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this, - Proc.read_Ci(r[1]), Proc.read_Ci(r[2])); + Proc.sync_Ci(r[1]), Proc.sync_Ci(r[2])); return; case CONV2DS: Proc.Procp.protocol.conv2ds(Proc.Procp, *this); @@ -1042,6 +1042,21 @@ inline void Instruction::execute(Processor& Proc) const case TRUNC_PR: Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp); return; + case SECSHUFFLE: + Proc.Procp.secure_shuffle(*this); + return; + case GSECSHUFFLE: + Proc.Proc2.secure_shuffle(*this); + return; + case GENSECSHUFFLE: + Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this)); + return; + case APPLYSHUFFLE: + Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3))); + return; + case DELSHUFFLE: + Proc.Procp.delete_shuffle(Proc.read_Ci(r[0])); + return; case CHECK: { CheckJob job; @@ -1056,14 +1071,14 @@ inline void Instruction::execute(Processor& Proc) const Proc.PC += (signed int) n; break; case JMPI: - Proc.PC += (signed int) Proc.read_Ci(r[0]); + Proc.PC += (signed int) Proc.sync_Ci(r[0]); break; case JMPNZ: - if (Proc.read_Ci(r[0]) != 0) + if (Proc.sync_Ci(r[0]) != 0) { Proc.PC += (signed int) n; } break; case JMPEQZ: - if (Proc.read_Ci(r[0]) == 0) + if (Proc.sync_Ci(r[0]) == 0) { Proc.PC += (signed int) n; } break; case PRINTREG: @@ -1123,7 +1138,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.join_tape(r[0]); break; case CRASH: - if (Proc.read_Ci(r[0])) + if (Proc.sync_Ci(r[0])) throw crash_requested(); break; case STARTGRIND: @@ -1146,7 +1161,7 @@ inline void Instruction::execute(Processor& Proc) const // *** case LISTEN: // listen for connections at port number n - Proc.external_clients.start_listening(Proc.read_Ci(r[0])); + Proc.external_clients.start_listening(Proc.sync_Ci(r[0])); break; case ACCEPTCLIENTCONNECTION: { @@ -1335,4 +1350,15 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c out << "]"; } +template +typename T::clear Instruction::sanitize(SubProcessor& proc, int reg) const +{ + if (not T::real_shares(proc.P)) + return 1; + auto& res = proc.get_C_ref(reg); + if (res.is_zero()) + throw Processor_Error("Division by zero from register"); + return res; +} + #endif diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index e0299c2f3..ce90e1b27 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -30,7 +30,7 @@ void Machine::init_binary_domains(int security_parameter, int lg2) if (not is_same()) { - if (sgf2n::clear::degree() < security_parameter) + if (sgf2n::mac_key_type::length() < security_parameter) { cerr << "Security parameter needs to be at most n in GF(2^n)." << endl; @@ -469,7 +469,10 @@ void Machine::run(const string& progname) for (auto& x : comm_stats) rounds += x.second.rounds; cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds - << " rounds (party " << my_number << ")" << endl; + << " rounds (party " << my_number; + if (threads.size() > 1) + cerr << "; rounds counted double due to multi-threading"; + cerr << ")" << endl; auto& P = *this->P; Bundle bundle(P); diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index d4c66e9aa..85ee25d0b 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -36,7 +36,9 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - ("Bit length of GF(2^n) field (default: " + to_string(V::default_degree()) + ")").c_str(), // Help description. + ("Bit length of GF(2^n) field (default: " + + to_string(V::default_degree()) + "; options are " + + V::options() + ")").c_str(), // Help description. "-lg2", // Flag token. "--lg2" // Flag token. ); diff --git a/Processor/Processor.h b/Processor/Processor.h index 38ea7f258..927e93279 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -20,6 +20,7 @@ #include "Tools/CheckVector.h" #include "GC/Processor.h" #include "GC/ShareThread.h" +#include "Protocols/SecureShuffle.h" class Program; @@ -31,6 +32,8 @@ class SubProcessor DataPositions bit_usage; + SecureShuffle shuffler; + void resize(size_t size) { C.resize(size); S.resize(size); } template friend class Processor; @@ -70,6 +73,11 @@ class SubProcessor size_t b); void conv2ds(const Instruction& instruction); + void secure_shuffle(const Instruction& instruction); + size_t generate_secure_shuffle(const Instruction& instruction); + void apply_shuffle(const Instruction& instruction, int handle); + void delete_shuffle(int handle); + void input_personal(const vector& args); void send_personal(const vector& args); void private_output(const vector& args); @@ -127,6 +135,10 @@ class ArithmeticProcessor : public ProcessorBase ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num), sent(0), rounds(0), opts(opts) {} + virtual ~ArithmeticProcessor() + { + } + bool use_stdin() { return thread_num == 0 and opts.interactive; @@ -146,6 +158,11 @@ class ArithmeticProcessor : public ProcessorBase CheckVector& get_Ci() { return Ci; } + virtual long sync_Ci(size_t) const + { + throw not_implemented(); + } + void shuffle(const Instruction& instruction); void bitdecint(const Instruction& instruction); }; @@ -241,6 +258,10 @@ class Processor : public ArithmeticProcessor cint get_inverse2(unsigned m); + // synchronize in asymmetric protocols + long sync_Ci(size_t i) const; + long sync(long x) const; + private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d74594b3d..861e8cfe0 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -9,6 +9,7 @@ #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" #include "GC/ShareThread.hpp" +#include "Protocols/SecureShuffle.hpp" #include #include @@ -23,6 +24,7 @@ SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& template SubProcessor::SubProcessor(typename T::MAC_Check& MC, Preprocessing& DataF, Player& P, ArithmeticProcessor* Proc) : + shuffler(*this), Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC), bit_prep(bit_usage) { @@ -340,6 +342,9 @@ void Processor::read_socket_private(int client_id, // Tolerent to no file if no shares yet persisted. template void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { + if (not sint::real_shares(P)) + return; + string filename; filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; @@ -370,6 +375,9 @@ template void Processor::write_shares_to_file(long start_pos, const vector& data_registers) { + if (not sint::real_shares(P)) + return; + string filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -633,6 +641,33 @@ void SubProcessor::conv2ds(const Instruction& instruction) } } +template +void SubProcessor::secure_shuffle(const Instruction& instruction) +{ + SecureShuffle(S, instruction.get_size(), instruction.get_n(), + instruction.get_r(0), instruction.get_r(1), *this); +} + +template +size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction) +{ + return shuffler.generate(instruction.get_n()); +} + +template +void SubProcessor::apply_shuffle(const Instruction& instruction, int handle) +{ + shuffler.apply(S, instruction.get_size(), instruction.get_start()[2], + instruction.get_start()[0], instruction.get_start()[1], handle, + instruction.get_start()[4]); +} + +template +void SubProcessor::delete_shuffle(int handle) +{ + shuffler.del(handle); +} + template void SubProcessor::input_personal(const vector& args) { @@ -690,4 +725,25 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } +template +long Processor::sync_Ci(size_t i) const +{ + return sync(read_Ci(i)); +} + +template +long Processor::sync(long x) const +{ + if (not sint::symmetric) + { + // send number to dealer + if (P.my_num() == 0) + P.send_long(P.num_players() - 1, x); + if (not sint::real_shares(P)) + return P.receive_long(0); + } + + return x; +} + #endif diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index 626942212..8527f98f7 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -50,7 +50,10 @@ RingMachine::RingMachine(int argc, const char** argv, case L: \ machine.template run, V>(); \ break; - X(64) X(72) X(128) X(192) + X(64) +#ifndef FEWER_RINGS + X(72) X(128) X(192) +#endif #ifdef RING_SIZE X(RING_SIZE) #endif diff --git a/Programs/Source/dijkstra_example.mpc b/Programs/Source/dijkstra_example.mpc new file mode 100644 index 000000000..950fe331f --- /dev/null +++ b/Programs/Source/dijkstra_example.mpc @@ -0,0 +1,50 @@ +# example code for graph with vertices 0,1,2 and with following weights +# 0 -> 1: 5 +# 0 -> 2: 20 +# 1 -> 2: 10 + +# output should be the following +# from 0 to 0 at cost 0 via vertex 0 +# from 0 to 1 at cost 5 via vertex 0 +# from 0 to 2 at cost 15 via vertex 1 + +from oram import OptimalORAM +from dijkstra import dijkstra + +# structure for edges +# contains tuples of form (neighbor, cost, last neighbor bit) +edges = OptimalORAM(4, # number of edges + entry_size=(2, # enough bits for vertices + 5, # enough bits for costs + 1) # always one +) + +# first edge from vertex 0 +edges[0] = (1, 5, 0) +# second and last edge from vertex 0 +edges[1] = (2, 20, 1) +# edge from vertex 1 +edges[2] = (2, 10, 1) +# dummy edge from vertex 2 to itself +edges[3] = (2, 0, 1) + +# structure assigning edge list indices to vertices +e_index = OptimalORAM(3, # number vertices + entry_size=2) # enough bits for edge indices + +# edges from 0 start at 0 +e_index[0] = 0 +# edges from 1 start at 2 +e_index[1] = 2 +# edges from 2 start at 3 +e_index[2] = 3 + +source = sint(0) + +res = dijkstra(source, edges, e_index, OptimalORAM) + +@for_range(res.size) +def _(i): + import util + print_ln('from %s to %s at cost %s via vertex %s', source.reveal(), i, + res[i][0].reveal(), res[i][1].reveal()) diff --git a/Programs/Source/dijkstra_tutorial.mpc b/Programs/Source/dijkstra_tutorial.mpc deleted file mode 100644 index 7ab220237..000000000 --- a/Programs/Source/dijkstra_tutorial.mpc +++ /dev/null @@ -1,9 +0,0 @@ -import dijkstra -from path_oram import OptimalORAM - -n = 1000 - -dist = dijkstra.test_dijkstra_on_cycle(n, OptimalORAM) - -for i in range(n): - print_ln('%s: %s', i, dist[i][0].reveal()) diff --git a/Protocols/Dealer.h b/Protocols/Dealer.h new file mode 100644 index 000000000..cc2c45baf --- /dev/null +++ b/Protocols/Dealer.h @@ -0,0 +1,36 @@ +/* + * Dealer.h + * + */ + +#ifndef PROTOCOLS_DEALER_H_ +#define PROTOCOLS_DEALER_H_ + +#include "Beaver.h" + +template +class Dealer : public Beaver +{ + SeededPRNG G; + +public: + Dealer(Player& P) : + Beaver(P) + { + } + + T get_random() + { + if (T::real_shares(this->P)) + return G.get(); + else + return {}; + } + + vector get_relevant_players() + { + return vector(1, this->P.num_players() - 1); + } +}; + +#endif /* PROTOCOLS_DEALER_H_ */ diff --git a/Protocols/DealerInput.h b/Protocols/DealerInput.h index 7d0699da4..7f0a26dd5 100644 --- a/Protocols/DealerInput.h +++ b/Protocols/DealerInput.h @@ -24,6 +24,7 @@ class DealerInput : public InputBase DealerInput(SubProcessor& proc, typename T::MAC_Check&); DealerInput(typename T::MAC_Check&, Preprocessing&, Player& P); DealerInput(Player& P); + DealerInput(SubProcessor*, Player& P); ~DealerInput(); bool is_dealer(int player = -1); diff --git a/Protocols/DealerInput.hpp b/Protocols/DealerInput.hpp index 26bfb9a1a..8b1ea855a 100644 --- a/Protocols/DealerInput.hpp +++ b/Protocols/DealerInput.hpp @@ -10,7 +10,7 @@ template DealerInput::DealerInput(SubProcessor& proc, typename T::MAC_Check&) : - DealerInput(proc.P) + DealerInput(&proc, proc.P) { } @@ -23,6 +23,13 @@ DealerInput::DealerInput(typename T::MAC_Check&, Preprocessing&, template DealerInput::DealerInput(Player& P) : + DealerInput(0, P) +{ +} + +template +DealerInput::DealerInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P), to_send(P), shares(P.num_players()), from_dealer(false), sub_player(P) { @@ -68,8 +75,8 @@ void DealerInput::add_mine(const typename T::open_type& input, if (is_dealer()) { make_share(shares.data(), input, P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(to_send[i]); + for (int i = 0; i < P.num_players() - 1; i++) + shares.at(i).pack(to_send[i]); from_dealer = true; } else diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h index 5311f8132..4e6681366 100644 --- a/Protocols/DealerMC.h +++ b/Protocols/DealerMC.h @@ -25,6 +25,7 @@ class DealerMC : public MAC_Check_Base void prepare_open(const T& secret); void exchange(const Player& P); typename T::open_type finalize_raw(); + array finalize_several(int n); DealerMC& get_part_MC() { diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp index a9ddc035c..0f63b93dc 100644 --- a/Protocols/DealerMC.hpp +++ b/Protocols/DealerMC.hpp @@ -73,4 +73,11 @@ typename T::open_type DealerMC::finalize_raw() return {}; } +template +array DealerMC::finalize_several(int n) +{ + assert(sub_player); + return internal.finalize_several(n); +} + #endif /* PROTOCOLS_DEALERMC_HPP_ */ diff --git a/Protocols/DealerMatrixPrep.h b/Protocols/DealerMatrixPrep.h new file mode 100644 index 000000000..787397255 --- /dev/null +++ b/Protocols/DealerMatrixPrep.h @@ -0,0 +1,32 @@ +/* + * DealerMatrixPrep.h + * + */ + +#ifndef PROTOCOLS_DEALERMATRIXPREP_H_ +#define PROTOCOLS_DEALERMATRIXPREP_H_ + +#include "ShareMatrix.h" + +template +class DealerMatrixPrep : public BufferPrep> +{ + typedef BufferPrep> super; + typedef typename T::LivePrep LivePrep; + + int n_rows, n_inner, n_cols; + + LivePrep* prep; + +public: + DealerMatrixPrep(int n_rows, int n_inner, int n_cols, + typename T::LivePrep&, DataPositions& usage); + + void set_protocol(typename ShareMatrix::Protocol&) + { + } + + void buffer_triples(); +}; + +#endif /* PROTOCOLS_DEALERMATRIXPREP_H_ */ diff --git a/Protocols/DealerMatrixPrep.hpp b/Protocols/DealerMatrixPrep.hpp new file mode 100644 index 000000000..faf98ec77 --- /dev/null +++ b/Protocols/DealerMatrixPrep.hpp @@ -0,0 +1,87 @@ +/* + * DealerMatrixPrep.hpp + * + */ + +#include "DealerMatrixPrep.h" + +template +DealerMatrixPrep::DealerMatrixPrep(int n_rows, int n_inner, int n_cols, + typename T::LivePrep& prep, DataPositions& usage) : + super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), + prep(&prep) +{ +} + +template +void append_shares(vector& os, + ValueMatrix& M, PRNG& G) +{ + size_t n = os.size(); + for (auto& value : M.entries) + { + T sum; + for (size_t i = 0; i < n - 2; i++) + { + auto share = G.get(); + sum += share; + share.pack(os[i]); + } + (value - sum).pack(os[n - 2]); + } +} + +template +ShareMatrix receive_shares(octetStream& o, int n, int m) +{ + ShareMatrix res(n, m); + for (size_t i = 0; i < res.entries.size(); i++) + res.entries.v.push_back(o.get()); + return res; +} + +template +void DealerMatrixPrep::buffer_triples() +{ + assert(this->prep); + assert(this->prep->proc); + auto& P = this->prep->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + int batch_size = 100; + if (not T::real_shares(P)) + { + SeededPRNG G; + ValueMatrix A(n_rows, n_inner), B(n_inner, n_cols), + C(n_rows, n_cols); + for (int i = 0; i < P.num_players() - 1; i++) + os[i].reserve( + batch_size * T::size() + * (A.entries.size() + B.entries.size() + + C.entries.size())); + for (int i = 0; i < batch_size; i++) + { + A.randomize(G); + B.randomize(G); + C = A * B; + append_shares(os, A, G); + append_shares(os, B, G); + append_shares(os, C, G); + this->triples.push_back({{{n_rows, n_inner}, {n_inner, n_cols}, + {n_rows, n_cols}}}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < batch_size; i++) + { + auto& o = to_receive.back(); + this->triples.push_back({{receive_shares(o, n_rows, n_inner), + receive_shares(o, n_inner, n_cols), + receive_shares(o, n_rows, n_cols)}}); + } + } +} diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h index ae28ec691..417fdbac7 100644 --- a/Protocols/DealerPrep.h +++ b/Protocols/DealerPrep.h @@ -11,6 +11,13 @@ template class DealerPrep : virtual public BitPrep { + friend class DealerMatrixPrep; + + template + void buffer_inverses(true_type); + template + void buffer_inverses(false_type); + template void buffer_edabits(int n_bits, true_type); template @@ -23,8 +30,14 @@ class DealerPrep : virtual public BitPrep } void buffer_triples(); + void buffer_inverses(); void buffer_bits(); + void buffer_inputs(int player) + { + this->buffer_inputs_as_usual(player, this->proc); + } + void buffer_dabits(ThreadQueues* = 0); void buffer_edabits(int n_bits, ThreadQueues*); void buffer_sedabits(int n_bits, ThreadQueues*); diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index d4a0a91dd..cc010dd71 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -45,6 +45,57 @@ void DealerPrep::buffer_triples() } } +template +void DealerPrep::buffer_inverses() +{ + buffer_inverses(T::invertible); +} + +template +template +void DealerPrep::buffer_inverses(false_type) +{ + throw not_implemented(); +} + +template +template +void DealerPrep::buffer_inverses(true_type) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T tuple[2]; + while (tuple[0] == 0) + tuple[0] = G.get(); + tuple[1] = tuple[0].invert(); + for (auto& value : tuple) + { + make_share(shares.data(), typename T::clear(value), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + this->inverses.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->inverses.push_back(to_receive.back().get>().get()); + } +} + template void DealerPrep::buffer_bits() { diff --git a/Protocols/DealerShare.h b/Protocols/DealerShare.h index 38900ff37..e59e19494 100644 --- a/Protocols/DealerShare.h +++ b/Protocols/DealerShare.h @@ -13,12 +13,16 @@ template class DealerPrep; template class DealerInput; template class DealerMC; template class DirectDealerMC; +template class DealerMatrixPrep; +template class Hemi; namespace GC { class DealerSecret; } +template class Dealer; + template class DealerShare : public SemiShare { @@ -30,22 +34,26 @@ class DealerShare : public SemiShare typedef DealerMC MAC_Check; typedef DirectDealerMC Direct_MC; - typedef Beaver Protocol; + typedef Hemi Protocol; typedef DealerInput Input; typedef DealerPrep LivePrep; typedef ::PrivateOutput PrivateOutput; + typedef DealerMatrixPrep MatrixPrep; + typedef Dealer BasicProtocol; + static false_type dishonest_majority; const static bool needs_ot = false; + const static bool symmetric = false; static string type_short() { return "DD" + string(1, T::type_char()); } - static int threshold(int) + static bool real_shares(const Player& P) { - throw runtime_error("undefined threshold"); + return P.my_num() != P.num_players() - 1; } static This constant(const T& other, int my_num, diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index c0a269d1a..e5bb9e9e5 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -33,6 +33,7 @@ class FakeShare : public T, public ShareInterface static const bool has_trunc_pr = true; static const bool dishonest_majority = false; + static const bool malicious = false; static string type_short() { diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index f43260ea1..0aa61bcba 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -13,22 +13,24 @@ * Matrix multiplication optimized with semi-homomorphic encryption */ template -class Hemi : public Semi +class Hemi : public T::BasicProtocol { - map, HemiMatrixPrep*> matrix_preps; + map, typename T::MatrixPrep*> matrix_preps; DataPositions matrix_usage; + MatrixMC mc; + ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); public: Hemi(Player& P) : - Semi(P) + T::BasicProtocol(P) { } ~Hemi(); - HemiMatrixPrep& get_matrix_prep(const array& dimensions, + typename T::MatrixPrep& get_matrix_prep(const array& dimensions, SubProcessor& processor); void matmulsm(SubProcessor& processor, CheckVector& source, diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 1b3d8f5ba..1549e2cf4 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -21,12 +21,12 @@ Hemi::~Hemi() } template -HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, +typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, SubProcessor& processor) { if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, - new HemiMatrixPrep(dims[0], dims[1], dims[2], + new typename T::MatrixPrep(dims[0], dims[1], dims[2], dynamic_cast(processor.DataF), matrix_usage)}); return *matrix_preps.at(dims); @@ -52,22 +52,27 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); - for (int k = 0; k < dim[1]; k++) + if (not T::real_shares(processor.P)) { - for (int i = 0; i < dim[0]; i++) + matrix_multiply(A, B, processor); + return; + } + + for (int i = 0; i < dim[0]; i++) + for (int k = 0; k < dim[1]; k++) { auto kk = Proc->get_Ci().at(dim[4] + k); auto ii = Proc->get_Ci().at(dim[3] + i); - A[{i, k}] = source.at(a + ii * dim[7] + kk); + A.entries.v.push_back(source.at(a + ii * dim[7] + kk)); } + for (int k = 0; k < dim[1]; k++) for (int j = 0; j < dim[2]; j++) { auto jj = Proc->get_Ci().at(dim[6] + j); auto ll = Proc->get_Ci().at(dim[5] + k); - B[{k, j}] = source.at(b + ll * dim[8] + jj); + B.entries.v.push_back(source.at(b + ll * dim[8] + jj)); } - } auto res = matrix_multiply(A, B, processor); @@ -94,13 +99,16 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[1] = min(max_inner, A.n_cols - i); subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); - MatrixMC mc; beaver.init(prep, mc); beaver.init_mul(); - beaver.prepare_mul(A.from(0, i, subdim.data()), - B.from(i, j, subdim.data() + 1)); - beaver.exchange(); - C.add_from_col(j, beaver.finalize_mul()); + bool for_real = T::real_shares(processor.P); + beaver.prepare_mul(A.from(0, i, subdim.data(), for_real), + B.from(i, j, subdim.data() + 1, for_real)); + if (for_real) + { + beaver.exchange(); + C.add_from_col(j, beaver.finalize_mul()); + } } } @@ -150,6 +158,15 @@ void Hemi::conv2ds(SubProcessor& processor, array dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}}); ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); + if (not T::real_shares(processor.P)) + { + matrix_multiply(A, B, processor); + return; + } + + A.entries.init(); + B.entries.init(); + for (int i_batch = 0; i_batch < batch_size; i_batch ++) { size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in; diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index 4a85cbe34..ddf7e186f 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -10,6 +10,7 @@ template class HemiPrep; template class Hemi; +template class HemiMatrixPrep; template class HemiShare : public SemiShare @@ -26,6 +27,9 @@ class HemiShare : public SemiShare typedef typename conditional, Beaver>::type Protocol; typedef HemiPrep LivePrep; + typedef HemiMatrixPrep MatrixPrep; + typedef Semi BasicProtocol; + static const bool needs_ot = false; static const bool local_mul = true; static true_type triple_matmul; diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 19d5e72d5..fccd2ef57 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -298,7 +298,8 @@ void TreeSum::start(vector& values, const Player& P) { // send from the root player os.reset_write_head(); - for (unsigned int i=0; i::~Direct_MAC_Check() { template void direct_add_openings(vector& values, const PlayerBase& P, vector& os) { - for (unsigned int i=0; i(); } template diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index e855214fd..1f745251e 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -13,6 +13,7 @@ using namespace std; #include "Tools/PointerVector.h" template class Preprocessing; +template class MatrixMC; /** * Abstract base class for opening protocols @@ -20,6 +21,8 @@ template class Preprocessing; template class MAC_Check_Base { + friend class MatrixMC; + protected: /* MAC Share */ typename T::mac_key_type::Scalar alphai; @@ -59,6 +62,7 @@ class MAC_Check_Base /// Get next opened value virtual typename T::clear finalize_open(); virtual typename T::open_type finalize_raw(); + array finalize_several(size_t n); /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 59c6c5dec..47528e006 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -70,6 +70,13 @@ typename T::open_type MAC_Check_Base::finalize_raw() return values.next(); } +template +array MAC_Check_Base::finalize_several(size_t n) +{ + assert(values.left() >= n); + return {{values.skip(0), values.skip(n)}}; +} + template void MAC_Check_Base::CheckFor(const typename T::open_type& value, const vector& shares, const Player& P) diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index e6f3a8a6a..7c94b5d81 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -42,8 +42,12 @@ class MaliciousRep3Share : public Rep3Share typedef GC::MaliciousRepSecret bit_type; + // indicate security relevance of field size + typedef T mac_key_type; + const static bool expensive = true; static const bool has_trunc_pr = false; + static const bool malicious = true; static string type_short() { diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 17eec6f11..631ef7667 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -160,7 +160,7 @@ template void CommMaliciousRepMC::POpen_Begin(vector& values, const vector& S, const Player& P) { - assert(T::length == 2); + assert(T::vector_length == 2); (void)values; os.resize(2); for (auto& o : os) diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index ceedc9157..332996ddd 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -45,6 +45,8 @@ class MaliciousShamirShare : public ShamirShare typedef GC::MaliciousCcdSecret bit_type; #endif + static const bool malicious = true; + static string type_short() { return "M" + super::type_short(); diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index afb456621..786276974 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -122,6 +122,7 @@ class Rep3Share : public RepShare const static bool expensive = false; const static bool variable_players = false; static const bool has_trunc_pr = true; + static const bool malicious = false; static string type_short() { diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h index 5e197804d..7befb7f4f 100644 --- a/Protocols/Rep4Share.h +++ b/Protocols/Rep4Share.h @@ -37,6 +37,8 @@ class Rep4Share : public RepShare typedef GC::Rep4Secret bit_type; + static const bool malicious = true; + static string type_short() { return "R4" + string(1, T::type_char()); diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index ba5b85c8e..48b014408 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -121,6 +121,8 @@ class ProtocolBase virtual void cisc(SubProcessor&, const Instruction&) { throw runtime_error("CISC instructions not implemented"); } + + virtual vector get_relevant_players(); }; /** @@ -146,7 +148,7 @@ class Replicated : public ReplicatedBase, public ProtocolBase static void assign(T& share, const typename T::clear& value, int my_num) { - assert(T::length == 2); + assert(T::vector_length == 2); share.assign_zero(); if (my_num < 2) share[my_num] = value; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 2d9eba572..f398da7fe 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -28,7 +28,7 @@ ProtocolBase::ProtocolBase() : template Replicated::Replicated(Player& P) : ReplicatedBase(P) { - assert(T::length == 2); + assert(T::vector_length == 2); } template @@ -152,6 +152,16 @@ T ProtocolBase::get_random() return res; } +template +vector ProtocolBase::get_relevant_players() +{ + vector res; + int n = dynamic_cast(*this).P.num_players(); + for (int i = 0; i < T::threshold(n) + 1; i++) + res.push_back(i); + return res; +} + template void Replicated::init_mul() { diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 9bb3c30a3..9e1498df0 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -71,7 +71,7 @@ class ReplicatedInput : public PrepLessInput ReplicatedInput(SubProcessor* proc, Player& P) : PrepLessInput(proc), proc(proc), P(P), protocol(P) { - assert(T::length == 2); + assert(T::vector_length == 2); expect.resize(P.num_players()); this->reset_all(P); } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index e72c0d839..4d875a3b2 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -28,7 +28,7 @@ void ReplicatedMC::POpen_Begin(vector&, template void ReplicatedMC::prepare(const vector& S) { - assert(T::length == 2); + assert(T::vector_length == 2); o.reset_write_head(); to_send.reset_write_head(); to_send.reserve(S.size() * T::value_type::size()); diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h new file mode 100644 index 000000000..a90c6e64f --- /dev/null +++ b/Protocols/SecureShuffle.h @@ -0,0 +1,53 @@ +/* + * SecureShuffle.h + * + */ + +#ifndef PROTOCOLS_SECURESHUFFLE_H_ +#define PROTOCOLS_SECURESHUFFLE_H_ + +#include +using namespace std; + +template class SubProcessor; + +template +class SecureShuffle +{ + SubProcessor& proc; + vector to_shuffle; + vector> config; + vector tmp; + int unit_size; + + vector>>> shuffles; + size_t n_shuffle; + bool exact; + + void player_round(int config_player); + void generate(int config_player, int n_shuffle); + + void waksman(vector& a, int depth, int start); + void cond_swap(T& x, T& y, const T& b); + + void iter_waksman(bool reverse = false); + void waksman_round(int size, bool inwards, bool reverse); + + void pre(vector& a, size_t n, size_t input_base); + void post(vector& a, size_t n, size_t input_base); + +public: + SecureShuffle(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc); + + SecureShuffle(SubProcessor& proc); + + int generate(int n_shuffle); + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse); + + void del(int handle); +}; + +#endif /* PROTOCOLS_SECURESHUFFLE_H_ */ diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp new file mode 100644 index 000000000..d2b0676ac --- /dev/null +++ b/Protocols/SecureShuffle.hpp @@ -0,0 +1,328 @@ +/* + * SecureShuffle.hpp + * + */ + +#ifndef PROTOCOLS_SECURESHUFFLE_HPP_ +#define PROTOCOLS_SECURESHUFFLE_HPP_ + +#include "SecureShuffle.h" +#include "Tools/Waksman.h" + +#include +#include + +template +SecureShuffle::SecureShuffle(SubProcessor& proc) : + proc(proc), unit_size(0), n_shuffle(0), exact(false) +{ +} + +template +SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc) : + proc(proc), unit_size(unit_size) +{ + pre(a, n, input_base); + + for (auto i : proc.protocol.get_relevant_players()) + player_round(i); + + post(a, n, output_base); +} + +template +void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse) +{ + this->unit_size = unit_size; + + pre(a, n, input_base); + + auto& shuffle = shuffles.at(handle); + assert(shuffle.size() == proc.protocol.get_relevant_players().size()); + + if (reverse) + for (auto it = shuffle.end(); it > shuffle.begin(); it--) + { + this->config = *(it - 1); + iter_waksman(reverse); + } + else + for (auto& config : shuffle) + { + this->config = config; + iter_waksman(reverse); + } + + post(a, n, output_base); +} + +template +void SecureShuffle::del(int handle) +{ + shuffles.at(handle).clear(); +} + +template +void SecureShuffle::pre(vector& a, size_t n, size_t input_base) +{ + n_shuffle = n / unit_size; + assert(unit_size * n_shuffle == n); + size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); + exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; + to_shuffle.clear(); + + if (exact) + { + to_shuffle.resize(n_shuffle_pow2 * unit_size); + for (size_t i = 0; i < n; i++) + to_shuffle[i] = a[input_base + i]; + } + else + { + // sorting power of two elements together with indicator bits + to_shuffle.resize((unit_size + 1) << int(ceil(log2(n_shuffle)))); + for (size_t i = 0; i < n_shuffle; i++) + { + for (int j = 0; j < unit_size; j++) + to_shuffle[i * (unit_size + 1) + j] = a[input_base + + i * unit_size + j]; + to_shuffle[i * (unit_size + 1) + unit_size] = T::constant(1, + proc.P.my_num(), proc.MC.get_alphai()); + } + this->unit_size++; + } +} + +template +void SecureShuffle::post(vector& a, size_t n, size_t output_base) +{ + if (exact) + for (size_t i = 0; i < n; i++) + a[output_base + i] = to_shuffle[i]; + else + { + auto& MC = proc.MC; + MC.init_open(proc.P); + int shuffle_unit_size = this->unit_size; + int unit_size = shuffle_unit_size - 1; + for (size_t i = 0; i < to_shuffle.size() / shuffle_unit_size; i++) + MC.prepare_open(to_shuffle.at((i + 1) * shuffle_unit_size - 1)); + MC.exchange(proc.P); + size_t i_shuffle = 0; + for (size_t i = 0; i < n_shuffle; i++) + { + auto bit = MC.finalize_open(); + if (bit == 1) + { + // only output real elements + for (int j = 0; j < unit_size; j++) + a.at(output_base + i_shuffle * unit_size + j) = + to_shuffle.at(i * shuffle_unit_size + j); + i_shuffle++; + } + } + if (i_shuffle != n_shuffle) + throw runtime_error("incorrect shuffle"); + } +} + +template +void SecureShuffle::player_round(int config_player) +{ + generate(config_player, n_shuffle); + iter_waksman(); +} + +template +int SecureShuffle::generate(int n_shuffle) +{ + int res = shuffles.size(); + shuffles.push_back({}); + auto& shuffle = shuffles.back(); + + for (auto i : proc.protocol.get_relevant_players()) + { + generate(i, n_shuffle); + shuffle.push_back(config); + } + + return res; +} + +template +void SecureShuffle::generate(int config_player, int n) +{ + auto& P = proc.P; + auto& input = proc.input; + input.reset_all(P); + int n_pow2 = 1 << int(ceil(log2(n))); + Waksman waksman(n_pow2); + + if (P.my_num() == config_player) + { + vector perm; + int shuffle_size = n; + for (int j = 0; j < n_pow2; j++) + perm.push_back(j); + SeededPRNG G; + for (int i = 0; i < shuffle_size; i++) + { + int j = G.get_uint(shuffle_size - i); + swap(perm[i], perm[i + j]); + } + + auto config_bits = waksman.configure(perm); + for (size_t i = 0; i < config_bits.size(); i++) + { + auto& x = config_bits[i]; + for (size_t j = 0; j < x.size(); j++) + if (waksman.matters(i, j)) + input.add_mine(int(x[j])); + else + assert(x[j] == 0); + } + } + else + for (size_t i = 0; i < waksman.n_bits(); i++) + input.add_other(config_player); + + input.exchange(); + config.clear(); + typename T::Protocol checker(P); + checker.init(proc.DataF, proc.MC); + checker.init_dotprod(); + auto one = T::constant(1, P.my_num(), proc.MC.get_alphai()); + for (size_t i = 0; i < waksman.n_rounds(); i++) + { + config.push_back({}); + for (int j = 0; j < n_pow2; j++) + { + if (waksman.matters(i, j)) + { + config.back().push_back(input.finalize(config_player)); + if (T::malicious) + checker.prepare_dotprod(config.back().back(), + one - config.back().back()); + } + else + config.back().push_back({}); + } + } + + if (T::malicious) + { + checker.next_dotprod(); + checker.exchange(); + assert( + typename T::clear( + proc.MC.open(checker.finalize_dotprod(waksman.n_bits()), + P)) == 0); + checker.check(); + } +} + +template +void SecureShuffle::waksman(vector& a, int depth, int start) +{ + int n = a.size(); + + if (n == 2) + { + cond_swap(a[0], a[1], config.at(depth).at(start)); + return; + } + + vector a0(n / 2), a1(n / 2); + for (int i = 0; i < n / 2; i++) + { + a0.at(i) = a.at(2 * i); + a1.at(i) = a.at(2 * i + 1); + + cond_swap(a0[i], a1[i], config.at(depth).at(i + start + n / 2)); + } + + waksman(a0, depth + 1, start); + waksman(a1, depth + 1, start + n / 2); + + for (int i = 0; i < n / 2; i++) + { + a.at(2 * i) = a0.at(i); + a.at(2 * i + 1) = a1.at(i); + cond_swap(a[2 * i], a[2 * i + 1], config.at(depth).at(i + start)); + } +} + +template +void SecureShuffle::cond_swap(T& x, T& y, const T& b) +{ + auto diff = proc.protocol.mul(x - y, b); + x -= diff; + y += diff; +} + +template +void SecureShuffle::iter_waksman(bool reverse) +{ + int n = to_shuffle.size() / unit_size; + + for (int depth = 0; depth < log2(n); depth++) + waksman_round(depth, true, reverse); + + for (int depth = log2(n) - 2; depth >= 0; depth--) + waksman_round(depth, false, reverse); +} + +template +void SecureShuffle::waksman_round(int depth, bool inwards, bool reverse) +{ + int n = to_shuffle.size() / unit_size; + assert((int) config.at(depth).size() == n); + int nblocks = 1 << depth; + int size = n / (2 * nblocks); + bool outwards = !inwards; + proc.protocol.init_mul(); + vector> indices; + indices.reserve(n / 2); + Waksman waksman(n); + for (int k = 0; k < n / 2; k++) + { + int j = k % size; + int i = k / size; + int base = 2 * i * size; + int in1 = base + j + j * inwards; + int in2 = in1 + inwards + size * outwards; + int out1 = base + j + j * outwards; + int out2 = out1 + outwards + size * inwards; + int i_bit = base + j + size * (outwards ^ reverse); + bool run = waksman.matters(depth, i_bit); + if (run) + { + for (int l = 0; l < unit_size; l++) + proc.protocol.prepare_mul(config.at(depth).at(i_bit), + to_shuffle.at(in1 * unit_size + l) + - to_shuffle.at(in2 * unit_size + l)); + } + indices.push_back({{in1, in2, out1, out2, run}}); + } + proc.protocol.exchange(); + tmp.resize(to_shuffle.size()); + for (int k = 0; k < n / 2; k++) + { + auto idx = indices.at(k); + for (int l = 0; l < unit_size; l++) + { + T diff; + if (idx[4]) + diff = proc.protocol.finalize_mul(); + tmp.at(idx[2] * unit_size + l) = to_shuffle.at( + idx[0] * unit_size + l) - diff; + tmp.at(idx[3] * unit_size + l) = to_shuffle.at( + idx[1] * unit_size + l) + diff; + } + } + swap(tmp, to_shuffle); +} + +#endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index b306d5c3d..432b599bb 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -78,6 +78,7 @@ class SemiShare : public T, public ShareInterface const static bool variable_players = true; const static bool expensive = false; static const bool has_trunc_pr = true; + static const bool malicious = false; static string type_short() { return "D" + string(1, T::type_char()); } diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index aea0bb97a..bf40cb287 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -49,6 +49,7 @@ class ShamirShare : public T, public ShareInterface const static bool dishonest_majority = false; const static bool variable_players = true; const static bool expensive = false; + const static bool malicious = true; static string type_short() { diff --git a/Protocols/Share.h b/Protocols/Share.h index e2a9f0bb5..9ca86cea7 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -56,6 +56,7 @@ class Share_ : public ShareInterface const static bool dishonest_majority = T::dishonest_majority; const static bool variable_players = T::variable_players; const static bool has_mac = true; + static const bool malicious = true; static int size() { return T::size() + V::size(); } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index e5af8dddd..a8ef7a224 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -40,12 +40,17 @@ class ShareInterface static const bool has_trunc_pr = false; static const bool has_split = false; static const bool has_mac = false; + static const bool malicious = false; static const false_type triple_matmul; + const static bool symmetric = true; + static const int default_length = 1; - static string type_short() { return "undef"; } + static string type_short() { throw runtime_error("don't call this"); } + + static bool real_shares(const Player&) { return true; } template static void split(vector, vector, int, T*, int, @@ -63,6 +68,8 @@ class ShareInterface template static void generate_mac_key(T&, U&) {} + + static int threshold(int) { throw runtime_error("undefined threshold"); } }; #endif /* PROTOCOLS_SHAREINTERFACE_H_ */ diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index 7f84213e6..b31aa7085 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -14,6 +14,124 @@ using namespace std; template class MatrixMC; +template +class NonInitVector +{ + template friend class NonInitVector; + + size_t size_; +public: + AddableVector v; + + NonInitVector(size_t size) : + size_(size) + { + v.reserve(size); + } + + template + NonInitVector(const NonInitVector& other) : + size_(other.size()), v(other.v) + { + } + + size_t size() const + { + return size_; + } + + void init() + { + v.resize(size_); + } + + void check() const + { +#ifdef DEBUG_MATRIX + assert(not v.empty()); +#endif + } + + typename vector::iterator begin() + { + check(); + return v.begin(); + } + + typename vector::iterator end() + { + check(); + return v.end(); + } + + T& at(size_t index) + { + check(); + return v.at(index); + } + + const T& at(size_t index) const + { +#ifdef DEBUG_MATRIX + assert(index < size()); +#endif + return (*this)[index]; + } + + T& operator[](size_t index) + { + check(); + return v[index]; + } + + const T& operator[](size_t index) const + { + check(); + return v[index]; + } + + NonInitVector operator-(const NonInitVector& other) const + { + assert(size() == other.size()); + NonInitVector res(size()); + if (other.v.empty()) + return *this; + else if (v.empty()) + { + res.init(); + res.v = res.v - other.v; + } + else + res.v = v - other.v; + return res; + } + + NonInitVector& operator+=(const NonInitVector& other) + { + assert(size() == other.size()); + if (not other.v.empty()) + { + if (v.empty()) + *this = other; + else + v += other.v; + } + return *this; + } + + bool operator!=(const NonInitVector& other) const + { + return v != other.v; + } + + void randomize(PRNG& G) + { + v.clear(); + for (size_t i = 0; i < size(); i++) + v.push_back(G.get()); + } +}; + template class ValueMatrix : public ValueInterface { @@ -21,7 +139,7 @@ class ValueMatrix : public ValueInterface public: int n_rows, n_cols; - AddableVector entries; + NonInitVector entries; static DataFieldType field_type() { @@ -48,15 +166,19 @@ class ValueMatrix : public ValueInterface T& operator[](const pair& indices) { +#ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); +#endif return entries.at(indices.first * n_cols + indices.second); } const T& operator[](const pair& indices) const { +#ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); +#endif return entries.at(indices.first * n_cols + indices.second); } @@ -80,6 +202,9 @@ class ValueMatrix : public ValueInterface { assert(n_cols == other.n_rows); This res(n_rows, other.n_cols); + if (entries.v.empty() or other.entries.v.empty()) + return res; + res.entries.init(); for (int i = 0; i < n_rows; i++) for (int j = 0; j < other.n_cols; j++) for (int k = 0; k < n_cols; k++) @@ -103,9 +228,9 @@ class ValueMatrix : public ValueInterface ValueMatrix transpose() const { ValueMatrix res(this->n_cols, this->n_rows); - for (int i = 0; i < this->n_rows; i++) - for (int j = 0; j < this->n_cols; j++) - res[{j, i}] = (*this)[{i, j}]; + for (int j = 0; j < this->n_cols; j++) + for (int i = 0; i < this->n_rows; i++) + res.entries.v.push_back((*this)[{i, j}]); return res; } @@ -139,7 +264,7 @@ class ShareMatrix : public ValueMatrix, public ShareInterface { This res(other.n_rows, other.n_cols); for (size_t i = 0; i < other.entries.size(); i++) - res.entries[i] = T::constant(other.entries[i], my_num, key); + res.entries.v.push_back(T::constant(other.entries[i], my_num, key)); res.check(); return res; } @@ -167,24 +292,29 @@ class ShareMatrix : public ValueMatrix, public ShareInterface ShareMatrix from_col(int start, int size) const { ShareMatrix res(this->n_rows, min(size, this->n_cols - start)); + res.entries.clear(); for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) - res[{i, j}] = (*this)[{i, start + j}]; + res.entries.v.push_back((*this)[{i, start + j}]); return res; } - ShareMatrix from(int start_row, int start_col, int* sizes) const + ShareMatrix from(int start_row, int start_col, int* sizes, bool for_real = + true) const { ShareMatrix res(min(sizes[0], this->n_rows - start_row), min(sizes[1], this->n_cols - start_col)); + if (not for_real) + return res; for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) - res[{i, j}] = (*this)[{start_row + i, start_col + j}]; + res.entries.v.push_back((*this)[{start_row + i, start_col + j}]); return res; } void add_from_col(int start, const ShareMatrix& other) { + this->entries.init(); for (int i = 0; i < this->n_rows; i++) for (int j = 0; j < other.n_cols; j++) (*this)[{i, start + j}] += other[{i, j}]; @@ -197,6 +327,9 @@ ShareMatrix operator*(const ValueMatrix& a, { assert(a.n_cols == b.n_rows); ShareMatrix res(a.n_rows, b.n_cols); + if (a.entries.v.empty() or b.entries.v.empty()) + return res; + res.entries.init(); for (int i = 0; i < a.n_rows; i++) for (int j = 0; j < b.n_cols; j++) for (int k = 0; k < a.n_cols; k++) @@ -208,9 +341,22 @@ ShareMatrix operator*(const ValueMatrix& a, template class MatrixMC : public MAC_Check_Base> { - typename T::MAC_Check inner; + typename T::MAC_Check& inner; public: + MatrixMC() : + inner( + *(OnlineOptions::singleton.direct ? + new typename T::Direct_MC : + new typename T::MAC_Check)) + { + } + + ~MatrixMC() + { + delete &inner; + } + void exchange(const Player& P) { inner.init_open(P); @@ -224,8 +370,15 @@ class MatrixMC : public MAC_Check_Base> for (auto& share : this->secrets) { this->values.push_back({share.n_rows, share.n_cols}); - for (auto& entry : this->values.back().entries) - entry = inner.finalize_open(); + if (share.entries.v.empty()) + for (size_t i = 0; i < share.entries.size(); i++) + inner.finalize_open(); + else + { + auto range = inner.finalize_several(share.entries.size()); + auto& v = this->values.back().entries.v; + v.insert(v.begin(), range[0], range[1]); + } } } }; diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h index f4f37dcd6..049881ffe 100644 --- a/Protocols/TemiShare.h +++ b/Protocols/TemiShare.h @@ -25,6 +25,9 @@ class TemiShare : public HemiShare typedef typename conditional, Beaver>::type Protocol; typedef TemiPrep LivePrep; + typedef HemiMatrixPrep MatrixPrep; + typedef Semi BasicProtocol; + static const bool needs_ot = false; static const bool local_mul = false; diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index bae415c4a..63b058e08 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -130,7 +130,6 @@ template void make_share(DealerShare* Sa, const T& a, int N, const U&, PRNG& G) { make_share((SemiShare*) Sa, a, N - 1, U(), G); - Sa[N - 1] = {}; } template @@ -273,6 +272,11 @@ inline string mac_filename(string directory, int playerno) + to_string(playerno); } +template <> +inline void write_mac_key(const string&, int, int, GC::NoValue) +{ +} + template void write_mac_key(const string& directory, int i, int nplayers, U key) { @@ -301,6 +305,11 @@ void read_mac_key(const string& directory, const Names& N, T& key) read_mac_key(directory, N.my_num(), N.num_players(), key); } +template <> +inline void read_mac_key(const string&, int, int, GC::NoValue&) +{ +} + template void read_mac_key(const string& directory, int player_num, int nplayers, U& key) { @@ -367,7 +376,7 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, } template -void read_global_mac_key(const string& directory, int nparties, U& key, false_type) +void read_global_mac_key(const string& directory, int nparties, U& key) { U pp; key.assign_zero(); @@ -383,15 +392,9 @@ void read_global_mac_key(const string& directory, int nparties, U& key, false_ty cout << "Final Keys : " << key << endl; } -template -void read_global_mac_key(const string&, int, U&, true_type) -{ -} - -template -void read_global_mac_key(const string& directory, int nparties, U& key) +template <> +inline void read_global_mac_key(const string&, int, GC::NoValue&) { - read_global_mac_key(directory, nparties, key, is_same()); } template @@ -579,14 +582,14 @@ void plain_edabits(vector& as, as.resize(max_size); bs.clear(); bs.resize(length); - bigint value; + Z2 value; for (int j = 0; j < max_size; j++) { if (not zero) - G.get_bigint(value, length, true); + value.randomize_part(G, length); as[j] = value; for (int k = 0; k < length; k++) - bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; + bs[k] ^= BitVec(value.get_bit(k)) << j; } } diff --git a/README.md b/README.md index a3f6741fd..cd0e9781c 100644 --- a/README.md +++ b/README.md @@ -101,8 +101,9 @@ The following table lists all protocols that are fully supported. | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | -| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep3 / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Malicious, honest supermajority | [Rep4](#honest-majority) | [Rep4](#honest-majority) | [Rep4](#honest-majority) | N/A | | Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A | Modulo prime and modulo 2^k are the two settings that allow @@ -280,6 +281,8 @@ compute the preprocessing time for a particular computation. - Python 3.5 or later - NTL library for homomorphic encryption (optional; tested with NTL 10.5) - If using macOS, Sierra or later + - Windows/VirtualBox: see [this + issue](https://github.com/data61/MP-SPDZ/issues/557) for a discussion #### Compilation diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index c7f8c371f..ec39b7728 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -84,7 +84,9 @@ not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& fil { } -gf2n_not_supported::gf2n_not_supported(int n) : - runtime_error("GF(2^" + to_string(n) + ") not supported") +gf2n_not_supported::gf2n_not_supported(int n, string options) : + runtime_error( + "GF(2^" + to_string(n) + ") not supported" + + (options.empty() ? "" : ", options are " + options)) { } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index bb347c6a8..a3ca3a5d0 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -281,7 +281,7 @@ class insufficient_memory : public runtime_error class gf2n_not_supported : public runtime_error { public: - gf2n_not_supported(int n); + gf2n_not_supported(int n, string options = ""); }; #endif diff --git a/Tools/PointerVector.h b/Tools/PointerVector.h index 32d1b46ee..404c4ee92 100644 --- a/Tools/PointerVector.h +++ b/Tools/PointerVector.h @@ -30,6 +30,15 @@ class PointerVector : public CheckVector { return (*this)[i++]; } + T* skip(size_t n) + { + i += n; + return &(*this)[i]; + } + size_t left() + { + return this->size() - i; + } }; #endif /* TOOLS_POINTERVECTOR_H_ */ diff --git a/Tools/Waksman.cpp b/Tools/Waksman.cpp new file mode 100644 index 000000000..a54b7766d --- /dev/null +++ b/Tools/Waksman.cpp @@ -0,0 +1,91 @@ +/* + * Waksman.cpp + * + */ + +#include "Waksman.h" + +#include +#include +#include + +template +void append(vector& x, const vector& y) +{ + x.insert(x.end(), y.begin(), y.end()); +} + +vector > Waksman::configure(const vector& perm) +{ + int n = perm.size(); + assert(n > 1); + + if (n == 2) + return {{perm[0] == 1, perm[0] == 1}}; + + vector I(n / 2); + vector O(n / 2, -1); + vector p0(n / 2, -1), p1(n / 2, -1), inv_perm(n); + + for (int i = 0; i < n; i++) + inv_perm[perm[i]] = i; + + while (true) + { + auto it = find(O.begin(), O.end(), -1); + if (it == O.end()) + break; + int j = 2 * (it - O.begin()); + O.at(j / 2) = 0; + int j0 = j; + + while (true) + { + int i = inv_perm.at(j); + p0.at(i / 2) = j / 2; + I.at(i / 2) = i % 2; + O.at(j / 2) = j % 2; + if (i % 2 == 1) + i--; + else + i++; + j = perm.at(i); + if (j % 2 == 1) + j--; + else + j++; + p1.at(i / 2) = perm.at(i) / 2; + if (j == j0) + break; + } + + if ((find(p1.begin(), p1.end(), -1) == p1.end()) + and (find(p0.begin(), p0.end(), -1) == p0.end())) + break; + } + + auto p0_config = configure(p0); + auto p1_config = configure(p1); + + vector> res; + res.push_back(I); + for (auto& x : O) + res.back().push_back(x); + + assert(p0_config.size() == p1_config.size()); + + for (size_t i = 0; i < p0_config.size(); i++) + { + res.push_back(p0_config.at(i)); + append(res.back(), p1_config.at(i)); + } + + assert(res.size() == Waksman(perm.size()).n_rounds()); + return res; +} + +Waksman::Waksman(int n_elements) : + n_elements(n_elements), nr(log2(n_elements)) +{ + assert(n_elements == (1 << nr)); +} diff --git a/Tools/Waksman.h b/Tools/Waksman.h new file mode 100644 index 000000000..521e990f9 --- /dev/null +++ b/Tools/Waksman.h @@ -0,0 +1,39 @@ +/* + * Waksman.h + * + */ + +#ifndef TOOLS_WAKSMAN_H_ +#define TOOLS_WAKSMAN_H_ + +#include +using namespace std; + +class Waksman +{ + int n_elements; + int nr; + +public: + static vector> configure(const vector& perm); + + Waksman(int n_elements); + + size_t n_rounds() const + { + return nr; + } + + bool matters(int i, int j) const + { + int block = n_elements >> i; + return block == 2 or j % block != block / 2; + } + + size_t n_bits() const + { + return nr * n_elements - (1 << (nr - 1)) + 1; + } +}; + +#endif /* TOOLS_WAKSMAN_H_ */ diff --git a/Utils/he-example.cpp b/Utils/he-example.cpp new file mode 100644 index 000000000..179028a5b --- /dev/null +++ b/Utils/he-example.cpp @@ -0,0 +1,97 @@ +/* + * he-example.cpp + * + */ + +#include "FHE/FHE_Params.h" +#include "FHE/NTL-Subs.h" +#include "FHE/FHE_Keys.h" +#include "FHE/Plaintext.h" + +void first_phase(string filename, int n_mults, int circuit_sec); +void second_phase(string filename); + +int main() +{ + for (int n_mults = 0; n_mults < 2; n_mults++) + for (int sec = 0; sec <= 120; sec += 40) + { + string filename = "mp-spdz-he"; + first_phase(filename, n_mults, sec); + second_phase(filename); + } +} + +void first_phase(string filename, int n_mults, int circuit_sec) +{ + // specify number of multiplications (at most one) and function privacy parameter + // increase the latter to accommodate more operations + FHE_Params params(n_mults, circuit_sec); + + // generate parameters for computation modulo a 32-bit prime + params.basic_generation_mod_prime(32); + + // find computation modulus (depends on parameter generation) + cout << "computation modulo " << params.get_plaintext_modulus() << endl; + + // generate key pair + FHE_KeyPair pair(params); + pair.generate(); + + Plaintext_mod_prime plaintext(params); + + // set first two plaintext slots + plaintext.set_element(0, 4); + plaintext.set_element(1, -1); + + // encrypt + Ciphertext ciphertext = pair.pk.encrypt(plaintext); + + // store for second phase + octetStream os; + params.pack(os); + pair.pk.pack(os); + ciphertext.pack(os); + plaintext.pack(os); + pair.sk.pack(os); + ofstream out(filename); + os.output(out); +} + +void second_phase(string filename) +{ + // read from file + ifstream in(filename); + octetStream os; + os.input(in); + FHE_Params params; + FHE_PK pk(params); + FHE_SK sk(params); + Plaintext_mod_prime plaintext(params); + Ciphertext ciphertext(params); + + // parameter must be set correctly first + params.unpack(os); + pk.unpack(os); + ciphertext.unpack(os); + plaintext.unpack(os); + + if (params.n_mults() == 0) + // public-private multiplication is always available + ciphertext *= plaintext; + else + // private-private multiplication only with matching parameters + ciphertext = ciphertext.mul(pk, ciphertext); + + // re-randomize for circuit privacy + ciphertext.rerandomize(pk); + + // read secret key and decrypt + sk.unpack(os); + plaintext = sk.decrypt(ciphertext); + + cout << "should be 16: " << plaintext.element(0) << endl; + cout << "should be 1: " << plaintext.element(1) << endl; + assert(plaintext.element(0) == 16); + assert(plaintext.element(1) == 1); +} diff --git a/doc/Doxyfile b/doc/Doxyfile index 9820ba50c..5f1143e31 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h ../FHE/Ciphertext.h ../FHE/FHE_Keys.h ../FHE/FHE_Params.h ../FHE/Plaintext.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/homomorphic-encryption.rst b/doc/homomorphic-encryption.rst new file mode 100644 index 000000000..95c922fd8 --- /dev/null +++ b/doc/homomorphic-encryption.rst @@ -0,0 +1,31 @@ +Homomorphic Encryption +---------------------- + +MP-SPDZ uses BGV encryption for triple generation in a number of +protocols. This involves zero-knowledge proofs in some protocols and +considerations about function privacy in all of them. The interface +described below allows directly accessing the basic cryptographic +operations in contexts where these considerations are not relevant. +See ``Utils/he-example.cpp`` for some example code. + + +Reference +~~~~~~~~~ + +.. doxygenclass:: FHE_Params + :members: + +.. doxygenclass:: FHE_KeyPair + :members: + +.. doxygenclass:: FHE_SK + :members: + +.. doxygenclass:: FHE_PK + :members: + +.. doxygenclass:: Plaintext + :members: + +.. doxygenclass:: Ciphertext + :members: diff --git a/doc/index.rst b/doc/index.rst index d2a2c4dcd..59caa58de 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -175,6 +175,7 @@ Reference non-linear preprocessing add-protocol + homomorphic-encryption troubleshooting diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 268084806..8b02fa3ed 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -148,12 +148,14 @@ AVX/AVX2 instructions are deactivated (see e.g. `here `_), which causes a dramatic performance loss. Deactivate Hyper-V/Hypervisor using:: + bcdedit /set hypervisorlaunchtype off DISM /Online /Disable-Feature:Microsoft-Hyper-V Performance can be further increased when compiling MP-SPDZ yourself: :: + sudo apt-get update sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm git clone https://github.com/data61/MP-SPDZ.git From 88534961b3492b7804f2de0d8425f5ee0b401bdb Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 2 Jun 2022 17:12:11 +0200 Subject: [PATCH 069/221] Fix biases in PRNG. --- Tools/random.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Tools/random.cpp b/Tools/random.cpp index 7cf1924f3..94f97cf60 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -179,10 +179,11 @@ unsigned int PRNG::get_uint(int upper) } // not power of 2 unsigned int r, reduced; + bool use_char = upper <= 128; do { - r = (upper < 255) ? get_uchar() : get_uint(); + r = use_char ? get_uchar() : get_uint(); reduced = r % upper; - } while (int(r - reduced + (upper - 1)) < 0); + } while (int(r - reduced + (upper - 1)) > (use_char ? 256 : 0)); return reduced; } @@ -260,7 +261,9 @@ void PRNG::get_bigint(bigint& res, int n_bits, bool positive) octet* bytes = (octet*) words; words[n_words - 1] = 0; get_octets(bytes, n_bytes); - octet mask = (1 << (n_bits % 8)) - 1; + octet mask = -1; + if (n_bits % 8 > 0) + mask = (1 << (n_bits % 8)) - 1; bytes[n_bytes - 1] &= mask; mpz_import(res.get_mpz_t(), n_words, -1, sizeof(word), -1, 0, bytes); if (not positive and (get_bit())) From 6755a8fa5105be5d196f2a8b71e2609fad564611 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 13 Jun 2022 11:20:55 +0200 Subject: [PATCH 070/221] Python client example. --- ExternalIO/README.md | 10 ++- ExternalIO/bankers-bonus-client.py | 35 +++++++++ ExternalIO/client.py | 113 +++++++++++++++++++++++++++++ ExternalIO/domains.py | 67 +++++++++++++++++ 4 files changed, 223 insertions(+), 2 deletions(-) create mode 100755 ExternalIO/bankers-bonus-client.py create mode 100644 ExternalIO/client.py create mode 100644 ExternalIO/domains.py diff --git a/ExternalIO/README.md b/ExternalIO/README.md index d4f99288b..f5f418ed9 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -2,12 +2,13 @@ The ExternalIO directory contains an example of managing I/O between external cl ## Working Examples -[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a +[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) and +[bankers-bonus-client.py](./bankers-bonus-client.py) act as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output as described by [Damgård et al.](https://eprint.iacr.org/2015/1006) The computation allows up to eight clients to input a number and computes the client -with the largest input. You can run it as follows from the main +with the largest input. You can run the C++ code as follows from the main directory: ``` make bankers-bonus-client.x @@ -30,6 +31,11 @@ protocol script. The setup scripts generate the necessary SSL certificates and keys. Therefore, if you run the computation on different hosts, you will have to distribute the `*.pem` files. +For the Python client, make sure to install +[gmpy2](https://pypi.org/project/gmpy2), and run +`ExternalIO/bankers-bonus-client.py` instead of +`bankers-bonus-client.x`. + ## I/O MPC Instructions ### Connection Setup diff --git a/ExternalIO/bankers-bonus-client.py b/ExternalIO/bankers-bonus-client.py new file mode 100755 index 000000000..d0f8d285b --- /dev/null +++ b/ExternalIO/bankers-bonus-client.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 + +import sys + +sys.path.append('.') + +from client import * +from domains import * + +client_id = int(sys.argv[1]) +n_parties = int(sys.argv[2]) +bonus = float(sys.argv[3]) +finish = int(sys.argv[4]) + +client = Client(['localhost'] * n_parties, 14000, client_id) + +type = client.specification.get_int(4) + +if type == ord('R'): + domain = Z2(client.specification.get_int(4)) +elif type == ord('p'): + domain = Fp(client.specification.get_bigint()) +else: + raise Exception('invalid type') + +for socket in client.sockets: + os = octetStream() + os.store(finish) + os.Send(socket) + +for x in bonus, bonus * 2 ** 16: + client.send_private_inputs([domain(x)]) + + print('Winning client id is :', + client.receive_outputs(domain, 1)[0].v % 2 ** 64) diff --git a/ExternalIO/client.py b/ExternalIO/client.py new file mode 100644 index 000000000..819647ff3 --- /dev/null +++ b/ExternalIO/client.py @@ -0,0 +1,113 @@ +import socket, ssl +import struct +import time + +class Client: + def __init__(self, hostnames, port_base, my_client_id): + ctx = ssl.SSLContext() + name = 'C%d' % my_client_id + prefix = 'Player-Data/%s' % name + ctx.load_cert_chain(certfile=prefix + '.pem', keyfile=prefix + '.key') + ctx.load_verify_locations(capath='Player-Data') + + self.sockets = [] + for i, hostname in enumerate(hostnames): + for j in range(10000): + try: + plain_socket = socket.create_connection( + (hostname, port_base + i)) + break + except ConnectionRefusedError: + if j < 60: + time.sleep(1) + else: + raise + octetStream(b'%d' % my_client_id).Send(plain_socket) + self.sockets.append(ctx.wrap_socket(plain_socket, + server_hostname='P%d' % i)) + + self.specification = octetStream() + self.specification.Receive(self.sockets[0]) + + def receive_triples(self, T, n): + triples = [[0, 0, 0] * n] + os = octetStream() + for socket in self.sockets: + os.Receive(socket) + for triple in triples: + for i in range(3): + t = T() + t.unpack(os) + triple[i] += t + res = [] + for triple in triples: + prod = triple[0] * triple[1] + if prod != triple[2]: + raise Exception( + 'invalid triple, diff %s' % hex(prod.v - triple[2].v)) + return triples + + def send_private_inputs(self, values): + T = type(values[0]) + triples = self.receive_triples(T, len(values)) + os = octetStream() + for value, triple in zip(values, triples): + (value + triple[0]).pack(os) + for socket in self.sockets: + os.Send(socket) + + def receive_outputs(self, T, n): + triples = self.receive_triples(T, n) + return [triple[0] for triple in triples] + +class octetStream: + def __init__(self, value=None): + self.buf = b'' + self.ptr = 0 + if value is not None: + self.buf += value + + def reset_write_head(self): + self.buf = b'' + self.ptr = 0 + + def Send(self, socket): + socket.send(struct.pack('= 0) + + def __add__(self, other): + try: + res = self.v + other.v + except: + res = self.v + other + return type(self)(res) + + def __mul__(self, other): + try: + res = self.v * other.v + except: + res = self.v * other + return type(self)(res) + + __radd__ = __add__ + + def __eq__(self, other): + return self.v == other.v + + def __neq__(self, other): + return self.v != other.v + + def unpack(self, os): + self.v = 0 + buf = os.consume(self.n_bytes) + for i, b in enumerate(buf): + self.v += b << (i * 8) + + def pack(self, os): + v = self.v + for i in range(self.n_bytes): + os.buf += struct.pack('>= 8 + +def Z2(k): + class Z(Domain): + modulus = 2 ** k + n_words = (k + 63) // 64 + n_bytes = (k + 7) // 8 + + return Z + +def Fp(mod): + import gmpy2 + + class Fp(Domain): + modulus = mod + n_words = (modulus.bit_length() + 63) // 64 + n_bytes = 8 * n_words + R = 2 ** (64 * n_words) % modulus + R_inv = gmpy2.invert(R, modulus) + + def unpack(self, os): + Domain.unpack(self, os) + self.v = self.v * self.R_inv % self.modulus + + def pack(self, os): + Domain.pack(type(self)(self.v * self.R), os) + + return Fp From 4c8e616b58710ddc118d5dee1de6ac7be41c908c Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 14 Jun 2022 16:14:37 +0200 Subject: [PATCH 071/221] Improved binary circuit functionality. --- Compiler/GC/types.py | 35 +++++++++++++++++++++++++++++++---- Compiler/types.py | 3 ++- Compiler/util.py | 5 +++++ Processor/Instruction.hpp | 1 + 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index fdd987225..4287844a2 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -661,6 +661,9 @@ def malloc(size, creator_tape=None): return sbit.malloc(size * n, creator_tape=creator_tape) @staticmethod def n_elements(): + return 1 + @staticmethod + def mem_size(): return n @classmethod def get_input_from(cls, player): @@ -692,22 +695,33 @@ def __init__(self, other=None, size=None): self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n @classmethod - def load_mem(cls, address): + def load_mem(cls, address, size=None): + if size not in (None, 1): + assert isinstance(address, int) or len(address) == 1 + sb = sbits.get_type(size) + return cls.from_vec(sb.bit_compose( + sbit.load_mem(address + i + j * n) for j in range(size)) + for i in range(n)) if not isinstance(address, int) and len(address) == n: return cls.from_vec(sbit.load_mem(x) for x in address) else: return cls.from_vec(sbit.load_mem(address + i) for i in range(n)) def store_in_mem(self, address): + size = 1 for x in self.v: - assert util.is_constant(x) or x.n == 1 - v = [sbit.conv(x) for x in self.v] + if not util.is_constant(x): + size = max(size, x.n) + v = [sbits.get_type(size).conv(x) for x in self.v] if not isinstance(address, int) and len(address) == n: + assert max_n == 1 for x, y in zip(v, address): x.store_in_mem(y) else: + assert isinstance(address, int) or len(address) == 1 for i in range(n): - v[i].store_in_mem(address + i) + for j, x in enumerate(v[i].bit_decompose()): + x.store_in_mem(address + i + j * n) def reveal(self): if len(self) > cbits.unit: return self.elements()[0].reveal() @@ -861,6 +875,19 @@ def bit_xor(self, other): return self ^ other def right_shift(self, m, k, security=None, signed=True): return self.from_vec(self.v[m:]) + def tree_reduce(self, function): + elements = self.elements() + while len(elements) > 1: + size = len(elements) + half = size // 2 + left = elements[:half] + right = elements[half:2*half] + odd = elements[2*half:] + sides = [self.from_vec(sbitvec(x).v) for x in (left, right)] + red = function(*sides) + elements = red.elements() + elements += odd + return self.from_vec(sbitvec(elements).v) class bit(object): n = 1 diff --git a/Compiler/types.py b/Compiler/types.py index 098f493f0..735fddea2 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5701,7 +5701,8 @@ def __getitem__(self, index): self.sub_cache[key] = \ Array(self.sizes[1], self.value_type, \ self.address + index * self.sizes[1] * - self.value_type.n_elements(), \ + self.value_type.n_elements() * \ + self.value_type.mem_size(), \ debug=self.debug) else: self.sub_cache[key] = \ diff --git a/Compiler/util.py b/Compiler/util.py index aa491e422..9d84df226 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -116,6 +116,11 @@ def round_to_int(x): return x.round_to_int() def tree_reduce(function, sequence): + try: + return sequence.tree_reduce(function) + except AttributeError: + pass + sequence = list(sequence) assert len(sequence) > 0 n = len(sequence) diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 5bed37037..a0f7a490a 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -730,6 +730,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case ANDM: case NOTS: case NOTCB: + case TRANS: size = DIV_CEIL(n, 64); break; case CONVCBIT2S: From ec1d302b03bb8e747ef3ce51d7a04e50c2c8f796 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 23 Jun 2022 14:42:54 +0200 Subject: [PATCH 072/221] Local right shift for GF(2^n). --- Compiler/instructions.py | 1 + Compiler/instructions_base.py | 4 ++-- Compiler/types.py | 24 +++++++++++++++--------- Processor/Instruction.h | 5 +++-- Processor/Instruction.hpp | 4 ++++ Protocols/Rep3Share.h | 2 +- Protocols/Semi2kShare.h | 11 ----------- Protocols/SemiShare.h | 25 +++++++++++++++++++++++++ 8 files changed, 51 insertions(+), 25 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 5f5b82dbc..aac0c34cf 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1051,6 +1051,7 @@ class shrci(base.ClearShiftInstruction): code = base.opcodes['SHRCI'] op = '__rshift__' +@base.gf2n @base.vectorize class shrsi(base.ClearShiftInstruction): """ Bitwise right shift of secret register (vector) by (constant) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d598d8a71..3a56e6043 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -207,8 +207,8 @@ CONDPRINTPLAIN = 0xE1, INTOUTPUT = 0xE6, FLOATOUTPUT = 0xE7, - GBITDEC = 0x184, - GBITCOM = 0x185, + GBITDEC = 0x18A, + GBITCOM = 0x18B, # Secure socket INITSECURESOCKET = 0x1BA, RESPSECURESOCKET = 0x1BB diff --git a/Compiler/types.py b/Compiler/types.py index 735fddea2..93991df12 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2126,6 +2126,21 @@ def reveal_to(self, player): res = personal(player, masked.reveal() - mask[1]) return res + @set_instruction_type + @vectorize + def raw_right_shift(self, length): + """ Local right shift in supported protocols. + In integer-like protocols, the output is potentially off by one. + + :param length: number of bits + """ + res = type(self)() + shrsi(res, self, length) + return res + + def raw_mod2m(self, m): + return self - (self.raw_right_shift(m) << m) + class sint(_secret, _int): """ @@ -2668,15 +2683,6 @@ def split_to_two_summands(self, length, get_carry=False): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize - def raw_right_shift(self, length): - res = sint() - shrsi(res, self, length) - return res - - def raw_mod2m(self, m): - return self - (self.raw_right_shift(m) << m) - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. diff --git a/Processor/Instruction.h b/Processor/Instruction.h index fd91e35d3..f3caf5659 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -284,8 +284,9 @@ enum // Bitwise shifts GSHLCI = 0x182, GSHRCI = 0x183, - GBITDEC = 0x184, - GBITCOM = 0x185, + GSHRSI = 0x184, + GBITDEC = 0x18A, + GBITCOM = 0x18B, // Conversion GCONVINT = 0x1C0, GCONVGF2N = 0x1C1, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index a0f7a490a..5b0589b69 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -198,6 +198,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case GORCI: case GSHLCI: case GSHRCI: + case GSHRSI: case USE: case USE_INP: case USE_EDABIT: @@ -1006,6 +1007,9 @@ inline void Instruction::execute(Processor& Proc) const case SHRSI: sint::shrsi(Procp, *this); return; + case GSHRSI: + sgf2n::shrsi(Proc2, *this); + return; case OPEN: Proc.Procp.POpen(start, Proc.P, size); return; diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 786276974..fb02d26ff 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -71,7 +71,7 @@ class RepShare : public FixedVec, public ShareInterface template static void shrsi(SubProcessor& proc, const Instruction& inst) { - shrsi(proc, inst, T::invertible); + shrsi(proc, inst, T::prime_field); } template diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index cc41d0236..3d98cf1b6 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -85,17 +85,6 @@ class Semi2kShare : public SemiShare> } } } - - template - static void shrsi(SubProcessor& proc, const Instruction& inst) - { - for (int i = 0; i < inst.get_size(); i++) - { - auto& dest = proc.get_S_ref(inst.get_r(0) + i); - auto& source = proc.get_S_ref(inst.get_r(1) + i); - dest = source >> inst.get_n(); - } - } }; #endif /* PROTOCOLS_SEMI2KSHARE_H_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 432b599bb..8d9b11466 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -130,6 +130,31 @@ class SemiShare : public T, public ShareInterface { super::unpack(os, n_bits); } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + shrsi(proc, inst, T::prime_field); + } + + template + static void shrsi(SubProcessor&, const Instruction&, + true_type) + { + throw runtime_error("shrsi not implemented"); + } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst, + false_type) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; #endif /* PROTOCOLS_SEMISHARE_H_ */ From af5af2df251a84626dd451753039333e03ee51b7 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 23 Jun 2022 19:17:08 +0200 Subject: [PATCH 073/221] Fix bug in logistic regression benchmark. --- Programs/Source/logreg.mpc | 1 + 1 file changed, 1 insertion(+) diff --git a/Programs/Source/logreg.mpc b/Programs/Source/logreg.mpc index 492e46e02..036a6a23b 100644 --- a/Programs/Source/logreg.mpc +++ b/Programs/Source/logreg.mpc @@ -9,6 +9,7 @@ cfix.set_precision(16, 31) dim = int(program.args[1]) batch = int(program.args[2]) +ml.Layer.back_batch_size = batch try: n_iterations = int(program.args[3]) From 12a0f0c6c887658729a96425d5bb47c8f856817d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 24 Jun 2022 10:00:33 +0200 Subject: [PATCH 074/221] Fix bug when using specific port numbers. --- Networking/Server.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Networking/Server.cpp b/Networking/Server.cpp index f9ff3e897..f8b545b9f 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -176,9 +176,11 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, { pthread_create(&thread, 0, Server::start_in_thread, server = new Server(nplayers, portnum)); - N.init(my_num, portnum, my_port, hostname.c_str(), false); + bool default_port = my_port == Names::DEFAULT_PORT or my_port == portnum; + N.init(my_num, portnum, my_port, hostname.c_str(), not default_port); pthread_join(thread, 0); - N.set_server(server->get_socket()); + if (default_port) + N.set_server(server->get_socket()); delete server; } else From 31f32f5e667bdc5334a085f77ec5bd171b5c022e Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 24 Jun 2022 16:52:35 +0200 Subject: [PATCH 075/221] Fix bug in example code for adding protocols. --- Protocols/fake-stuff.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 63b058e08..564c79f9e 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -375,6 +375,13 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, return res; } +template<> +inline GC::NoValue read_generate_write_mac_key(Player&, + string) +{ + return {}; +} + template void read_global_mac_key(const string& directory, int nparties, U& key) { From a3b7d49cb9061e7165f338e1b88381bdba4cda57 Mon Sep 17 00:00:00 2001 From: Richard Hernandez <3848345+RHG101997@users.noreply.github.com> Date: Fri, 24 Jun 2022 12:57:04 -0400 Subject: [PATCH 076/221] Small error --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18cc92ae3..ac6435805 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.3.2 (Mai 27, 2022) +## 0.3.2 (May 27, 2022) - Secure shuffling - O(n log n) radix sorting From 8707864c30b15d651e767de17f6bdb09bcfd8bc8 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 24 Jun 2022 11:50:20 +0200 Subject: [PATCH 077/221] Improved error message for unclosed if blocks. --- Compiler/library.py | 1 + Compiler/program.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index ef2fe1ab6..cd32b84b9 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1519,6 +1519,7 @@ class State: pass state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ name='if-block') state.has_else = False + state.caller = [frame[1:] for frame in inspect.stack()[1:]] instructions.program.curr_tape.if_states.append(state) def else_then(): diff --git a/Compiler/program.py b/Compiler/program.py index 78b802e14..e06418f3d 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -776,7 +776,14 @@ def optimize(self, options): return if self.if_states: - raise CompilerError('Unclosed if/else blocks') + print('Tracebacks for open blocks:') + for state in self.if_states: + try: + print(util.format_trace(state.caller)) + except AttributeError: + pass + print() + raise CompilerError('Unclosed if/else blocks, see tracebacks above') if self.program.verbose: print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) From 1cdc6824207061268af28f71fe1a57dae93bac9f Mon Sep 17 00:00:00 2001 From: Richard Hernandez <3848345+RHG101997@users.noreply.github.com> Date: Tue, 28 Jun 2022 11:28:42 -0400 Subject: [PATCH 078/221] Duplicated word --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd0e9781c..81e4f88dc 100644 --- a/README.md +++ b/README.md @@ -581,7 +581,7 @@ secure versions of LowGear and HighGear. In all relevant programs, option `-T` activates [TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs in both. -Hemi and Soho denote the stripped version version of LowGear and +Hemi and Soho denote the stripped version of LowGear and HighGear, respectively, for semi-honest security similar to Semi, that is, generating additively shared Beaver triples using semi-homomorphic encryption. From 505d4838c18394e8bb87bc5bae5a8b9cc00d65ad Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 1 Jul 2022 12:18:30 +1000 Subject: [PATCH 079/221] Parameter for ring size in fake preprocessing. --- Utils/Fake-Offline.cpp | 51 +++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 823c318bb..c1d14d203 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -61,6 +61,9 @@ class FakeParams { } + template + void generate_ring(); + template void make_with_mac_key(int nplayers, int default_num, bool zero); template @@ -394,7 +397,7 @@ int main(int argc, const char** argv) 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Bit length of GF(p) field (default: 128)", // Help description. + "Bit length of GF(p) field (default: 128) and Z_2^k rings (default: 64)", // Help description. "-lgp", // Flag token. "--lgp" // Flag token. ); @@ -729,22 +732,12 @@ int FakeParams::generate() // replicated secret sharing only for three parties if (nplayers == 3) { - make_bits>({}, nplayers, nbitsp, zero); - make_basic>({}, nplayers, default_num, - zero); - make_basic>({}, nplayers, - default_num, zero); - make_with_mac_key>(nplayers, - default_num, zero); - make_mult_triples({}, nplayers, ntrip2, zero, prep_data_prefix); make_bits({}, nplayers, nbits2, zero); } else if (nplayers == 4) make_basic>({}, nplayers, default_num, zero); - make_basic>>({}, nplayers, default_num, zero); - make_basic>>({}, nplayers, default_num, zero); make_minimal({}, nplayers, default_num, zero); make_mult_triples({}, nplayers, default_num, zero, prep_data_prefix); @@ -778,6 +771,22 @@ int FakeParams::generate() generate_field(T::clear::prime_field); generate_field(true_type()); + // default + generate_ring<64>(); + + // reuse lgp for simplified interface + switch (lgp) + { + case 64: + break; +#define X(L) case L: generate_ring(); break; + X(128) X(192) X(256) + default: + cerr << "Not compiled for " << lgp << "-bit rings." << endl << "Add 'X(" + << lgp << "') to line " << (__LINE__ - 2) << " in " << __FILE__ << endl; + exit(1); + } + return 0; } @@ -803,3 +812,23 @@ void FakeParams::generate_field(true_type) default_num, zero); } } + +template +inline void FakeParams::generate_ring() +{ + if (nplayers == 3) + { + make_bits>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, + zero); + make_basic>({}, nplayers, + default_num, zero); + make_with_mac_key>(nplayers, + default_num, zero); + } + else if (nplayers == 4) + make_basic>({}, nplayers, default_num, zero); + + make_basic>>({}, nplayers, default_num, zero); + make_basic>>({}, nplayers, default_num, zero); +} From 7e2c0eda53289517eece67d8146a1c5cf689de23 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 4 Jul 2022 22:39:45 +1000 Subject: [PATCH 080/221] Splitting for any number of bits in Semi2k. --- Protocols/Semi2kShare.h | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index 3d98cf1b6..679c6bc8d 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -55,7 +55,6 @@ class Semi2kShare : public SemiShare> { auto& P = protocol.P; int my_num = P.my_num(); - assert(n_bits <= 64); int unit = GC::Clear::N_BITS; for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) { @@ -67,21 +66,27 @@ class Semi2kShare : public SemiShare> to_string(n) + "-way split not working with " + to_string(P.num_players()) + " parties"); - for (int i = 0; i < n_bits; i++) - for (int j = 0; j < n; j++) - dest.at(regs.at(n * i + j) + k) = {}; + for (int l = 0; l < n_bits; l += unit) + { + int base = l; + int n_left = min(n_bits - base, unit); + for (int i = base; i < base + n_left; i++) + for (int j = 0; j < n; j++) + dest.at(regs.at(n * i + j) + k) = {}; - square64 square; + square64 square; - for (int j = 0; j < m; j++) - square.rows[j] = Integer(source[j + start]).get(); + for (int j = 0; j < m; j++) + square.rows[j] = source[j + start].get_limb(l / unit); - square.transpose(m, n_bits); + square.transpose(m, n_left); - for (int j = 0; j < n_bits; j++) - { - auto& dest_reg = dest.at(regs.at(n * j + my_num) + k); - dest_reg = square.rows[j]; + for (int j = 0; j < n_left; j++) + { + auto& dest_reg = dest.at( + regs.at(n * (base + j) + my_num) + k); + dest_reg = square.rows[j]; + } } } } From 2a1ca6ae74350aaaeee8671ea381cb78f46ce155 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 7 Jul 2022 14:39:26 +0200 Subject: [PATCH 081/221] Fix cryptic assert statement in oram.py --- Compiler/oram.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/oram.py b/Compiler/oram.py index d4b434385..f218dfdb1 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1227,7 +1227,8 @@ def batch_init(self, values): """ Batch initalization. Obliviously shuffles and adds N entries to random leaf buckets. """ m = len(values) - assert((m & (m-1)) == 0) + if not (m & (m-1)) == 0: + raise CompilerError('Batch size must a power of 2.') if m != self.size: raise CompilerError('Batch initialization must have N values.') if self.value_type != sint: From 0d642822378bbd6f65bbf63478d3b076cd065d05 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 8 Jul 2022 14:53:21 +1000 Subject: [PATCH 082/221] Basic estimate for shuffling cost. --- Compiler/instructions.py | 41 ++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index aac0c34cf..5d6bf5fc6 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2408,8 +2408,36 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +class shuffle_base(base.DataInstruction): + n_relevant_parties = 2 + + @staticmethod + def logn(n): + return int(math.ceil(math.log(n, 2))) + + def add_gen_usage(self, req_node, n): + # hack for unknown usage + req_node.increment(('bit', 'inverse'), float('inf')) + # minimal usage with two relevant parties + logn = self.logn(n) + n_switches = logn * 2 ** logn + for i in range(self.n_relevant_parties): + req_node.increment((self.field_type, 'input', i), n_switches) + # multiplications for bit check + req_node.increment((self.field_type, 'triple'), + n_switches * self.n_relevant_parties) + + def add_apply_usage(self, req_node, n, record_size): + req_node.increment(('bit', 'inverse'), float('inf')) + logn = self.logn(n) + n_switches = logn * 2 ** logn * self.n_relevant_parties + if n != 2 ** logn: + record_size += 1 + req_node.increment((self.field_type, 'triple'), + n_switches * record_size) + @base.gf2n -class secshuffle(base.VectorInstruction, base.DataInstruction): +class secshuffle(base.VectorInstruction, shuffle_base): """ Secure shuffling. :param: destination (sint) @@ -2425,9 +2453,10 @@ def __init__(self, *args, **kwargs): assert len(args[0]) > args[2] def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', 0), float('inf')) + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) -class gensecshuffle(base.DataInstruction): +class gensecshuffle(shuffle_base): """ Generate secure shuffle to bit used several times. :param: destination (regint) @@ -2439,9 +2468,9 @@ class gensecshuffle(base.DataInstruction): arg_format = ['ciw','int'] def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', 0), float('inf')) + self.add_gen_usage(req_node, self.args[1]) -class applyshuffle(base.VectorInstruction, base.DataInstruction): +class applyshuffle(base.VectorInstruction, shuffle_base): """ Generate secure shuffle to bit used several times. :param: destination (sint) @@ -2461,7 +2490,7 @@ def __init__(self, *args, **kwargs): assert len(args[0]) > args[2] def add_usage(self, req_node): - req_node.increment((self.field_type, 'triple', 0), float('inf')) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) class delshuffle(base.Instruction): """ Delete secure shuffle. From dce0b427d21e19821a15e351cf6c0a564d5138b7 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 14 Jul 2022 15:48:03 +1000 Subject: [PATCH 083/221] Missing vectorization. --- Compiler/types.py | 2 ++ SimpleOT | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 93991df12..03e84e9b9 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5063,10 +5063,12 @@ def __ge__(self, other): """ Secret floating-point comparison. """ return 1 - (self < other) + @vectorize def __gt__(self, other): """ Secret floating-point comparison. """ return self.conv(other) < self + @vectorize def __le__(self, other): """ Secret floating-point comparison. """ return self.conv(other) >= self diff --git a/SimpleOT b/SimpleOT index 84d735226..96f8a97e6 160000 --- a/SimpleOT +++ b/SimpleOT @@ -1 +1 @@ -Subproject commit 84d73522619f90ba2aabce8d660baef1442aa26d +Subproject commit 96f8a97e6c049e11059337fd33457d84cb730f4c From 6db0ed1bc59e577746e10dd56b956a136365a196 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 15 Jul 2022 11:49:21 +1000 Subject: [PATCH 084/221] Array and matrix sorting in binary circuits. --- Compiler/library.py | 4 ++++ Compiler/types.py | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index cd32b84b9..524a55e16 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -460,6 +460,10 @@ def wrapper(self, *args): return wrapper def cond_swap(x,y): + from .types import SubMultiArray + if isinstance(x, (Array, SubMultiArray)): + b = x[0] > y[0] + return list(zip(*[b.cond_swap(xx, yy) for xx, yy in zip(x, y)])) b = x < y if isinstance(x, sfloat): res = ([], []) diff --git a/Compiler/types.py b/Compiler/types.py index 03e84e9b9..e74f26309 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -452,6 +452,10 @@ def carry_out(self, a, b): s = a ^ b return a ^ (s & (self ^ a)) + def cond_swap(self, a, b): + prod = self * (a ^ b) + return a ^ prod, b ^ prod + class _gf2n(_bit): """ :math:`\mathrm{GF}(2^n)` functionality. """ @@ -5646,7 +5650,8 @@ def sort(self, n_threads=None, batcher=False, n_bits=None): :param batcher: use Batcher's odd-even mergesort in any case :param n_bits: number of bits in keys (default: global bit length) """ - if batcher or self.value_type.n_elements() > 1: + if batcher or self.value_type.n_elements() > 1 or \ + program.options.binary: library.loopy_odd_even_merge_sort(self, n_threads=n_threads) else: if n_threads or 1 > 1: @@ -5739,6 +5744,13 @@ def __iter__(self): def to_array(self): return Array(self.total_size(), self.value_type, address=self.address) + def maybe_get(self, condition, index): + return self[condition * index] + + def maybe_set(self, condition, index, value): + for i, x in enumerate(value): + self.maybe_get(condition, index).maybe_set(condition, i, x) + def assign_all(self, value): """ Assign the same value to all entries. @@ -6326,6 +6338,11 @@ def sort(self, key_indices=None, n_bits=None): :param n_bits: number of bits in keys (default: global bit length) """ + if program.options.binary: + assert key_indices is None + assert len(self.sizes) == 2 + library.loopy_odd_even_merge_sort(self) + return if key_indices is None: key_indices = (0,) * (len(self.sizes) - 1) key_indices = (None,) + util.tuplify(key_indices) From 1a9bcd25e4019a0994fd73ca3253137a05b342d2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 16 Jul 2022 18:26:10 +1000 Subject: [PATCH 085/221] Correct SimpleOT version. --- SimpleOT | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SimpleOT b/SimpleOT index 96f8a97e6..84d735226 160000 --- a/SimpleOT +++ b/SimpleOT @@ -1 +1 @@ -Subproject commit 96f8a97e6c049e11059337fd33457d84cb730f4c +Subproject commit 84d73522619f90ba2aabce8d660baef1442aa26d From 1961a78fa8c341281285be8b76ba28366c36d7ef Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 16 Jul 2022 18:27:09 +1000 Subject: [PATCH 086/221] Fixed bug in MMO with prime fields longer than 1024 bits. --- Tools/MMO.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tools/MMO.hpp b/Tools/MMO.hpp index 4309e1fee..2081df838 100644 --- a/Tools/MMO.hpp +++ b/Tools/MMO.hpp @@ -18,7 +18,7 @@ void MMO::zeroIV() { octet key[AES_BLK_SIZE]; memset(key, 0, AES_BLK_SIZE * sizeof(octet)); - key[i] = i; + key[0] = i; setIV(i, key); } } From 1bbbcd277044da826995a0cd73aa08ab667a8d94 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 18 Jul 2022 14:02:06 +1000 Subject: [PATCH 087/221] Fixed bug in Python client. --- ExternalIO/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ExternalIO/client.py b/ExternalIO/client.py index 819647ff3..c0033275a 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -30,7 +30,7 @@ def __init__(self, hostnames, port_base, my_client_id): self.specification.Receive(self.sockets[0]) def receive_triples(self, T, n): - triples = [[0, 0, 0] * n] + triples = [[0, 0, 0] for i in range(n)] os = octetStream() for socket in self.sockets: os.Receive(socket) @@ -51,6 +51,7 @@ def send_private_inputs(self, values): T = type(values[0]) triples = self.receive_triples(T, len(values)) os = octetStream() + assert len(values) == len(triples) for value, triple in zip(values, triples): (value + triple[0]).pack(os) for socket in self.sockets: From ac252cd951fe0b6c4a99f0d1df899de8b7cf1b8d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 19 Jul 2022 12:27:09 +1000 Subject: [PATCH 088/221] Fixed bug in MemValue of size larger than one. --- Compiler/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index e74f26309..52a17cf8c 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -6633,7 +6633,9 @@ def read(self): :return: relevant basic type instance """ self.check() if program.curr_block != self.last_write_block: - self.register = self.value_type.load_mem(self.address) + self.register = self.value_type.load_mem( + self.address, size=self.size \ + if issubclass(self.value_type, _register) else None) self.last_write_block = program.curr_block return self.register From a1074ca69a654b190d8d4ae4810629c04869c0e3 Mon Sep 17 00:00:00 2001 From: prayforwind Date: Fri, 22 Jul 2022 09:27:42 +0800 Subject: [PATCH 089/221] Fix BMR's --input-file and --output-file --- BMR/RealProgramParty.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index ae69cb7f5..64efc5506 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -110,7 +110,8 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); - this->processor.open_input_file(N.my_num(), 0); + this->processor.open_input_file(N.my_num(), 0, online_opts.cmd_private_input_file); + this->processor.setup_redirection(P->my_num(), 0, online_opts, this->processor.out); shared_proc = new SubProcessor(dummy_proc, *MC, *prep, *P); From a0f5bb258e5826fd46e664d6707f4f78503a4c77 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 23 Jul 2022 09:38:39 -0700 Subject: [PATCH 090/221] Update Makefile for macs where Homebrew is installed in non-traditional locations --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 03366f89d..f0d6581b1 100644 --- a/Makefile +++ b/Makefile @@ -294,8 +294,8 @@ mpir: mpir-setup mac-setup: mac-machine-setup brew install openssl boost libsodium mpir yasm ntl - -echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I/opt/homebrew/opt/openssl/include -I/opt/homebrew/include >> CONFIG.mine - -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L/opt/homebrew/lib -L/opt/homebrew/opt/openssl/lib >> CONFIG.mine + -echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine + -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine # -echo USE_NTL = 1 >> CONFIG.mine ifeq ($(MACHINE), aarch64) From d39ca280e5118f62fa3455898937b1fa45ecc8c1 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 23 Jul 2022 09:48:21 -0700 Subject: [PATCH 091/221] fix sorting import bug in Compiler/types.py --- Compiler/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Compiler/types.py b/Compiler/types.py index 52a17cf8c..10b2c4242 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5657,7 +5657,7 @@ def sort(self, n_threads=None, batcher=False, n_bits=None): if n_threads or 1 > 1: raise CompilerError('multi-threaded sorting only implemented ' 'with Batcher\'s odd-even mergesort') - import sorting + from . import sorting sorting.radix_sort(self, self, n_bits=n_bits) def Array(self, size): @@ -6346,7 +6346,7 @@ def sort(self, key_indices=None, n_bits=None): if key_indices is None: key_indices = (0,) * (len(self.sizes) - 1) key_indices = (None,) + util.tuplify(key_indices) - import sorting + from . import sorting keys = self.get_vector_by_indices(*key_indices) sorting.radix_sort(keys, self, n_bits=n_bits) From 6db5f5d86187b238d927b04b3f9a445a1e90bd18 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 23 Jul 2022 10:03:53 -0700 Subject: [PATCH 092/221] update README to better represent running from other directories --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 81e4f88dc..18278c25f 100644 --- a/README.md +++ b/README.md @@ -466,21 +466,21 @@ for further examples. #### Compiling and running programs from external directories -Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: +Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all MP-SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: ``` -spdz$ cd ../ +MP-SPDZ$ cd ../ $ mkdir myprogs $ cd myprogs $ mkdir -p Programs/Source $ vi Programs/Source/test.mpc -$ ../spdz/compile.py test.mpc +$ ../MP-SPDZ/compile.py test.mpc $ ls Programs/ Bytecode Public-Input Schedules Source -$ ../spdz/Scripts/setup-online.sh +$ ../MP-SPDZ/Scripts/setup-online.sh $ ls Player-Data Programs -$ ../spdz/Scripts/run-online.sh test +$ ../MP-SPDZ/Scripts/run-online.sh test ``` ### TensorFlow inference From 101879f37a5164d0f18d51d52cdcda86a9c66b06 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 25 Jul 2022 13:40:18 +1000 Subject: [PATCH 093/221] Try loading dynamic library from root directory in scripts on Linux and macOS. --- Scripts/run-common.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 7e5e6d449..c6835069f 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -57,3 +57,6 @@ run_player() { players=${PLAYERS:-2} SPDZROOT=${SPDZROOT:-.} + +export LD_LIBRARY_PATH="$SPDZROOT:$LD_LIBRARY_PATH" +export DYLD_LIBRARY_PATH="$SPDZROOT:$DYLD_LIBRARY_PATH" From 81419ba32180c62c07e3033f7eab7c4f810b7184 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 25 Jul 2022 18:12:04 +1000 Subject: [PATCH 094/221] Fix bugs in matrix multiplication with binary circuits. --- Compiler/GC/types.py | 16 ++++++++++++++-- Compiler/types.py | 4 +++- GC/TinySecret.h | 2 +- Processor/Instruction.hpp | 16 +++++++++++++++- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 4287844a2..67410678f 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -17,6 +17,7 @@ import Compiler.GC.instructions as inst import operator import math +import itertools from functools import reduce class bits(Tape.Register, _structure, _bit): @@ -1182,9 +1183,20 @@ def __mul__(self, other): return self.from_vec(other * x for x in self.v) elif isinstance(other, sbitfixvec): return NotImplemented + other_bits = util.bit_decompose(other) + m = float('inf') + for x in itertools.chain(self.v, other_bits): + try: + m = min(m, x.n) + except: + pass + if m == 1: + op = operator.mul + else: + op = operator.and_ matrix = [] - for i, b in enumerate(util.bit_decompose(other)): - matrix.append([x & b for x in self.v[:len(self.v)-i]]) + for i, b in enumerate(other_bits): + matrix.append([op(x, b) for x in self.v[:len(self.v)-i]]) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ diff --git a/Compiler/types.py b/Compiler/types.py index 10b2c4242..1531c49d8 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5263,7 +5263,8 @@ def get_address(self, index, size=None): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() - if size is not None and isinstance(base, _register): + if size is not None and isinstance(base, _register) \ + and not issubclass(self.value_type, _vec): base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ @@ -6063,6 +6064,7 @@ def _(base, size): assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() + self.value_type.matrix_mul A = self.get_vector() B = other.get_vector() res_matrix.assign_vector( diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9cdde3dc7..85098d18e 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -146,7 +146,7 @@ class VectorSecret : public Secret if (this != &res) res.get_regs().assign(this->get_regs().begin(), this->get_regs().begin() - + max(size_t(n_bits), this->get_regs().size())); + + min(size_t(n_bits), this->get_regs().size())); res.resize_regs(n_bits); } diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 5b0589b69..1d7c883de 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -666,6 +666,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const return r[1] + size; else return 0; + case TRANS: + if (reg_type == SBIT) + { + int n_outputs = n; + auto& args = start; + int n_inputs = args.size() - n_outputs; + long long res = 0; + for (int i = 0; i < n_outputs; i++) + res = max(res, args[i] + DIV_CEIL(n_inputs, 64)); + for (int j = 0; j < n_inputs; j++) + res = max(res, args[n_outputs] + DIV_CEIL(n_outputs, 64)); + return res; + } + else + return 0; default: if (get_reg_type() != reg_type) return 0; @@ -731,7 +746,6 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case ANDM: case NOTS: case NOTCB: - case TRANS: size = DIV_CEIL(n, 64); break; case CONVCBIT2S: From 91960440f578909bc5f06f05ca42f7edc6d1c7ad Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 26 Jul 2022 16:04:56 +1000 Subject: [PATCH 095/221] Reveal sbitvec as list. --- Compiler/GC/types.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 67410678f..9fdb59043 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -724,15 +724,7 @@ def store_in_mem(self, address): for j, x in enumerate(v[i].bit_decompose()): x.store_in_mem(address + i + j * n) def reveal(self): - if len(self) > cbits.unit: - return self.elements()[0].reveal() - revealed = [cbit() for i in range(len(self))] - for i in range(len(self)): - try: - inst.reveal(1, revealed[i], self.v[i]) - except: - revealed[i] = cbit.conv(self.v[i]) - return cbits.get_type(len(self)).bit_compose(revealed) + return util.untuplify([x.reveal() for x in self.elements()]) @classmethod def two_power(cls, nn): return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) From 97efdbc01fa66adc91592995964e531387c370da Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 26 Jul 2022 17:05:38 +1000 Subject: [PATCH 096/221] Fix bug in preprocessing accounting. --- Compiler/library.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index 524a55e16..799f85d29 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -878,7 +878,7 @@ def loop_fn(i): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ - lambda x: ((stop - start) // step) * x[0] + lambda x: int(ceil(((stop - start) / step))) * x[0] def for_range(start, stop=None, step=None): """ From e1b45388768aab6465e2e6c914791903a6dabb48 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 11:17:41 -0400 Subject: [PATCH 097/221] refactor to add Compiler class --- Compiler/__init__.py | 27 --- Compiler/compilerLib.py | 451 ++++++++++++++++++++++++++++++++-------- compile.py | 106 ++-------- 3 files changed, 377 insertions(+), 207 deletions(-) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 9a22da461..6a0d6b1d5 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -2,30 +2,3 @@ from .GC import types as GC_types import inspect from .config import * -from .compilerLib import run - - -# add all instructions to the program VARS dictionary -compilerLib.VARS = {} -instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] - -for mod in (types, GC_types): - instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ - if t[1].__module__ == mod.__name__] - -instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ - if t[1].__module__ == library.__name__] - -for op in instr_classes: - compilerLib.VARS[op.__name__] = op - -# add open and input separately due to name conflict -compilerLib.VARS['open'] = instructions.asm_open -compilerLib.VARS['vopen'] = instructions.vasm_open -compilerLib.VARS['gopen'] = instructions.gasm_open -compilerLib.VARS['vgopen'] = instructions.vgasm_open -compilerLib.VARS['input'] = instructions.asm_input -compilerLib.VARS['ginput'] = instructions.gasm_input - -compilerLib.VARS['comparison'] = comparison -compilerLib.VARS['floatingpoint'] = floatingpoint diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index b2898e21a..591700c1e 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,94 +1,361 @@ -from Compiler.program import Program +import inspect +import os +import re +import sys +import tempfile +from optparse import OptionParser + from .GC import types as GC_types +from .program import Program, defaults -import sys -import re, tempfile, os - - -def run(args, options): - """ Compile a file and output a Program object. - - If options.merge_opens is set to True, will attempt to merge any - parallelisable open instructions. """ - - prog = Program(args, options) - VARS['program'] = prog - if options.binary: - VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) - VARS['sfix'] = GC_types.sbitfixvec - for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \ - 'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \ - 'squant': - del VARS[i] - - print('Compiling file', prog.infile) - f = open(prog.infile, 'rb') - - changed = False - if options.flow_optimization: - output = [] - if_stack = [] - for line in open(prog.infile): - if if_stack and not re.match(if_stack[-1][0], line): - if_stack.pop() - m = re.match( - '(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):', - line) - if m: - output.append('%s@for_range_opt(%s)\n' % (m.group(1), - m.group(3))) - output.append('%sdef _(%s):\n' % (m.group(1), m.group(2))) - changed = True - continue - m = re.match('(\s*)if(\W.*):', line) - if m: - if_stack.append((m.group(1), len(output))) - output.append('%s@if_(%s)\n' % (m.group(1), m.group(2))) - output.append('%sdef _():\n' % (m.group(1))) - changed = True - continue - m = re.match('(\s*)elif\s+', line) - if m: - raise CompilerError('elif not supported') - if if_stack: - m = re.match('%selse:' % if_stack[-1][0], line) - if m: - start = if_stack[-1][1] - ws = if_stack[-1][0] - output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws, - output[start]) - output.append('%s@else_\n' % ws) - output.append('%sdef _():\n' % ws) - continue - output.append(line) - if changed: - infile = tempfile.NamedTemporaryFile('w+', delete=False) - for line in output: - infile.write(line) - infile.seek(0) - else: - infile = open(prog.infile) - else: - infile = open(prog.infile) - - # make compiler modules directly accessible - sys.path.insert(0, 'Compiler') - # create the tapes - exec(compile(infile.read(), infile.name, 'exec'), VARS) - - if changed and not options.debug: - os.unlink(infile.name) - - prog.finalize() - - if prog.req_num: - print('Program requires at most:') - for x in prog.req_num.pretty(): - print(x) - - if prog.verbose: - print('Program requires:', repr(prog.req_num)) - print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) - print('Memory size:', dict(prog.allocated_mem)) - - return prog + +class Compiler: + def __init__(self): + self.usage = "usage: %prog [options] filename [args]" + self.build_option_parser() + self.VARS = {} + + def build_option_parser(self): + parser = OptionParser(usage=self.usage) + parser.add_option( + "-n", + "--nomerge", + action="store_false", + dest="merge_opens", + default=defaults.merge_opens, + help="don't attempt to merge open instructions", + ) + parser.add_option("-o", "--output", dest="outfile", help="specify output file") + parser.add_option( + "-a", + "--asm-output", + dest="asmoutfile", + help="asm output file for debugging", + ) + parser.add_option( + "-g", + "--galoissize", + dest="galois", + default=defaults.galois, + help="bit length of Galois field", + ) + parser.add_option( + "-d", + "--debug", + action="store_true", + dest="debug", + help="keep track of trace for debugging", + ) + parser.add_option( + "-c", + "--comparison", + dest="comparison", + default="log", + help="comparison variant: log|plain|inv|sinv", + ) + parser.add_option( + "-M", + "--preserve-mem-order", + action="store_true", + dest="preserve_mem_order", + default=defaults.preserve_mem_order, + help="preserve order of memory instructions; possible efficiency loss", + ) + parser.add_option( + "-O", + "--optimize-hard", + action="store_true", + dest="optimize_hard", + help="currently not in use", + ) + parser.add_option( + "-u", + "--noreallocate", + action="store_true", + dest="noreallocate", + default=defaults.noreallocate, + help="don't reallocate", + ) + parser.add_option( + "-m", + "--max-parallel-open", + dest="max_parallel_open", + default=defaults.max_parallel_open, + help="restrict number of parallel opens", + ) + parser.add_option( + "-D", + "--dead-code-elimination", + action="store_true", + dest="dead_code_elimination", + default=defaults.dead_code_elimination, + help="eliminate instructions with unused result", + ) + parser.add_option( + "-p", + "--profile", + action="store_true", + dest="profile", + help="profile compilation", + ) + parser.add_option( + "-s", + "--stop", + action="store_true", + dest="stop", + help="stop on register errors", + ) + parser.add_option( + "-R", + "--ring", + dest="ring", + default=defaults.ring, + help="bit length of ring (default: 0 for field)", + ) + parser.add_option( + "-B", + "--binary", + dest="binary", + default=defaults.binary, + help="bit length of sint in binary circuit (default: 0 for arithmetic)", + ) + parser.add_option( + "-F", + "--field", + dest="field", + default=defaults.field, + help="bit length of sint modulo prime (default: 64)", + ) + parser.add_option( + "-P", + "--prime", + dest="prime", + default=defaults.prime, + help="prime modulus (default: not specified)", + ) + parser.add_option( + "-I", + "--insecure", + action="store_true", + dest="insecure", + help="activate insecure functionality for benchmarking", + ) + parser.add_option( + "-b", + "--budget", + dest="budget", + default=defaults.budget, + help="set budget for optimized loop unrolling " "(default: 100000)", + ) + parser.add_option( + "-X", + "--mixed", + action="store_true", + dest="mixed", + help="mixing arithmetic and binary computation", + ) + parser.add_option( + "-Y", + "--edabit", + action="store_true", + dest="edabit", + help="mixing arithmetic and binary computation using edaBits", + ) + parser.add_option( + "-Z", + "--split", + default=defaults.split, + dest="split", + help="mixing arithmetic and binary computation " + "using direct conversion if supported " + "(number of parties as argument)", + ) + parser.add_option( + "-C", + "--CISC", + action="store_true", + dest="cisc", + help="faster CISC compilation mode", + ) + parser.add_option( + "-K", + "--keep-cisc", + dest="keep_cisc", + help="don't translate CISC instructions", + ) + parser.add_option( + "-l", + "--flow-optimization", + action="store_true", + dest="flow_optimization", + help="optimize control flow", + ) + parser.add_option( + "-v", + "--verbose", + action="store_true", + dest="verbose", + help="more verbose output", + ) + self.parser = parser + + def parse_args(self): + self.options, self.args = self.parser.parse_args() + if len(self.args) < 1: + self.parser.print_help() + return + + if self.options.optimize_hard: + print("Note that -O/--optimize-hard currently has no effect") + + def build_program(self): + self.prog = Program(self.args, self.options) + + def build_vars(self): + from . import comparison, floatingpoint, instructions, library, types + + # add all instructions to the program VARS dictionary + instr_classes = [ + t[1] for t in inspect.getmembers(instructions, inspect.isclass) + ] + + for mod in (types, GC_types): + instr_classes += [ + t[1] + for t in inspect.getmembers(mod, inspect.isclass) + if t[1].__module__ == mod.__name__ + ] + + instr_classes += [ + t[1] + for t in inspect.getmembers(library, inspect.isfunction) + if t[1].__module__ == library.__name__ + ] + + for op in instr_classes: + self.VARS[op.__name__] = op + + # add open and input separately due to name conflict + self.VARS["open"] = instructions.asm_open + self.VARS["vopen"] = instructions.vasm_open + self.VARS["gopen"] = instructions.gasm_open + self.VARS["vgopen"] = instructions.vgasm_open + self.VARS["input"] = instructions.asm_input + self.VARS["ginput"] = instructions.gasm_input + + self.VARS["comparison"] = comparison + self.VARS["floatingpoint"] = floatingpoint + + self.VARS["program"] = self.prog + if self.options.binary: + self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary)) + self.VARS["sfix"] = GC_types.sbitfixvec + for i in [ + "cint", + "cfix", + "cgf2n", + "sintbit", + "sgf2n", + "sgf2nint", + "sgf2nuint", + "sgf2nuint32", + "sgf2nfloat", + "sfloat", + "cfloat", + "squant", + ]: + del self.VARS[i] + + def prep_compile(self): + self.parse_args() + self.build_program() + self.build_vars() + + def compile_file(self): + """Compile a file and output a Program object. + + If options.merge_opens is set to True, will attempt to merge any + parallelisable open instructions.""" + print("Compiling file", self.prog.infile) + + with open(self.prog.infile, "rb") as f: + changed = False + if self.options.flow_optimization: + output = [] + if_stack = [] + for line in f: + if if_stack and not re.match(if_stack[-1][0], line): + if_stack.pop() + m = re.match( + r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_]+)\):", + line, + ) + if m: + output.append( + "%s@for_range_opt(%s)\n" % (m.group(1), m.group(3)) + ) + output.append("%sdef _(%s):\n" % (m.group(1), m.group(2))) + changed = True + continue + m = re.match(r"(\s*)if(\W.*):", line) + if m: + if_stack.append((m.group(1), len(output))) + output.append("%s@if_(%s)\n" % (m.group(1), m.group(2))) + output.append("%sdef _():\n" % (m.group(1))) + changed = True + continue + m = re.match(r"(\s*)elif\s+", line) + if m: + raise CompilerError("elif not supported") + if if_stack: + m = re.match("%selse:" % if_stack[-1][0], line) + if m: + start = if_stack[-1][1] + ws = if_stack[-1][0] + output[start] = re.sub( + r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start] + ) + output.append("%s@else_\n" % ws) + output.append("%sdef _():\n" % ws) + continue + output.append(line) + if changed: + infile = tempfile.NamedTemporaryFile("w+", delete=False) + for line in output: + infile.write(line) + infile.seek(0) + else: + infile = open(self.prog.infile) + else: + infile = open(self.prog.infile) + + # make compiler modules directly accessible + sys.path.insert(0, "Compiler") + # create the tapes + exec(compile(infile.read(), infile.name, "exec"), self.VARS) + + if changed and not self.options.debug: + os.unlink(infile.name) + + return self.finalize_compile() + + def compile_func(self, f): + self.prep_compile() + print(f"Compiling function: {f.__name__}") + f(self.VARS) + self.finalize_compile() + + def finalize_compile(self): + self.prog.finalize() + + if self.prog.req_num: + print("Program requires at most:") + for x in self.prog.req_num.pretty(): + print(x) + + if self.prog.verbose: + print("Program requires:", repr(self.prog.req_num)) + print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost()) + print("Memory size:", dict(self.prog.allocated_mem)) + + return self.prog diff --git a/compile.py b/compile.py index da1b69ee3..50671a044 100755 --- a/compile.py +++ b/compile.py @@ -12,100 +12,30 @@ # # See the compiler documentation at https://mp-spdz.readthedocs.io # for details on the Compiler package +from Compiler.compilerLib import Compiler -from optparse import OptionParser -from Compiler.program import defaults -import Compiler +def compilation(compiler): + prog = compiler.compile_file() -def main(): - usage = "usage: %prog [options] filename [args]" - parser = OptionParser(usage=usage) - parser.add_option("-n", "--nomerge", - action="store_false", dest="merge_opens", - default=defaults.merge_opens, - help="don't attempt to merge open instructions") - parser.add_option("-o", "--output", dest="outfile", - help="specify output file") - parser.add_option("-a", "--asm-output", dest="asmoutfile", - help="asm output file for debugging") - parser.add_option("-g", "--galoissize", dest="galois", - default=defaults.galois, - help="bit length of Galois field") - parser.add_option("-d", "--debug", action="store_true", dest="debug", - help="keep track of trace for debugging") - parser.add_option("-c", "--comparison", dest="comparison", default="log", - help="comparison variant: log|plain|inv|sinv") - parser.add_option("-M", "--preserve-mem-order", action="store_true", - dest="preserve_mem_order", - default=defaults.preserve_mem_order, - help="preserve order of memory instructions; possible efficiency loss") - parser.add_option("-O", "--optimize-hard", action="store_true", - dest="optimize_hard", help="currently not in use") - parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate", - default=defaults.noreallocate, help="don't reallocate") - parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open", - default=defaults.max_parallel_open, - help="restrict number of parallel opens") - parser.add_option("-D", "--dead-code-elimination", action="store_true", - dest="dead_code_elimination", - default=defaults.dead_code_elimination, - help="eliminate instructions with unused result") - parser.add_option("-p", "--profile", action="store_true", dest="profile", - help="profile compilation") - parser.add_option("-s", "--stop", action="store_true", dest="stop", - help="stop on register errors") - parser.add_option("-R", "--ring", dest="ring", default=defaults.ring, - help="bit length of ring (default: 0 for field)") - parser.add_option("-B", "--binary", dest="binary", default=defaults.binary, - help="bit length of sint in binary circuit (default: 0 for arithmetic)") - parser.add_option("-F", "--field", dest="field", default=defaults.field, - help="bit length of sint modulo prime (default: 64)") - parser.add_option("-P", "--prime", dest="prime", default=defaults.prime, - help="prime modulus (default: not specified)") - parser.add_option("-I", "--insecure", action="store_true", dest="insecure", - help="activate insecure functionality for benchmarking") - parser.add_option("-b", "--budget", dest="budget", default=defaults.budget, - help="set budget for optimized loop unrolling " - "(default: 100000)") - parser.add_option("-X", "--mixed", action="store_true", dest="mixed", - help="mixing arithmetic and binary computation") - parser.add_option("-Y", "--edabit", action="store_true", dest="edabit", - help="mixing arithmetic and binary computation using edaBits") - parser.add_option("-Z", "--split", default=defaults.split, dest="split", - help="mixing arithmetic and binary computation " - "using direct conversion if supported " - "(number of parties as argument)") - parser.add_option("-C", "--CISC", action="store_true", dest="cisc", - help="faster CISC compilation mode") - parser.add_option("-K", "--keep-cisc", dest="keep_cisc", - help="don't translate CISC instructions") - parser.add_option("-l", "--flow-optimization", action="store_true", - dest="flow_optimization", help="optimize control flow") - parser.add_option("-v", "--verbose", action="store_true", dest="verbose", - help="more verbose output") - options,args = parser.parse_args() - if len(args) < 1: - parser.print_help() - return + if prog.public_input_file is not None: + print( + "WARNING: %s is required to run the program" % prog.public_input_file.name + ) - if options.optimize_hard: - print('Note that -O/--optimize-hard currently has no effect') - def compilation(): - prog = Compiler.run(args, options) - - if prog.public_input_file is not None: - print('WARNING: %s is required to run the program' % \ - prog.public_input_file.name) - - if options.profile: +def main(compiler): + compiler.prep_compile() + if compiler.options.profile: import cProfile - p = cProfile.Profile().runctx('compilation()', globals(), locals()) - p.dump_stats(args[0] + '.prof') + + p = cProfile.Profile().runctx("compilation(compiler)", globals(), locals()) + p.dump_stats(compiler.args[0] + ".prof") p.print_stats(2) else: - compilation() + compilation(compiler) + -if __name__ == '__main__': - main() +if __name__ == "__main__": + compiler = Compiler() + main(compiler) From 1c6c75886f0332d3c0c6f4baefce6bdebf46b2d8 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 14:02:21 -0400 Subject: [PATCH 098/221] allow for name to be passed in for function compiler --- Compiler/compilerLib.py | 24 +- Compiler/program.py | 814 ++++++++++++++++++++++------------------ 2 files changed, 471 insertions(+), 367 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 591700c1e..bd9368ba3 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -7,11 +7,15 @@ from .GC import types as GC_types from .program import Program, defaults +from Compiler.exceptions import CompilerError class Compiler: - def __init__(self): - self.usage = "usage: %prog [options] filename [args]" + def __init__(self, usage=None): + if usage: + self.usage = usage + else: + self.usage = "usage: %prog [options] filename [args]" self.build_option_parser() self.VARS = {} @@ -201,15 +205,11 @@ def build_option_parser(self): def parse_args(self): self.options, self.args = self.parser.parse_args() - if len(self.args) < 1: - self.parser.print_help() - return - if self.options.optimize_hard: print("Note that -O/--optimize-hard currently has no effect") - def build_program(self): - self.prog = Program(self.args, self.options) + def build_program(self, name=None): + self.prog = Program(self.args, self.options, name=name) def build_vars(self): from . import comparison, floatingpoint, instructions, library, types @@ -266,9 +266,9 @@ def build_vars(self): ]: del self.VARS[i] - def prep_compile(self): + def prep_compile(self, name=None): self.parse_args() - self.build_program() + self.build_program(name=name) self.build_vars() def compile_file(self): @@ -339,8 +339,8 @@ def compile_file(self): return self.finalize_compile() - def compile_func(self, f): - self.prep_compile() + def compile_func(self, f, name): + self.prep_compile(name) print(f"Compiling function: {f.__name__}") f(self.VARS) self.finalize_compile() diff --git a/Compiler/program.py b/Compiler/program.py index e06418f3d..a94e44d3c 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -4,39 +4,40 @@ object that holds various properties of the computation. """ -from Compiler.config import * -from Compiler.exceptions import * -from Compiler.instructions_base import RegType -import Compiler.instructions -import Compiler.instructions_base -import Compiler.instructions_base as inst_base -from . import allocator as al -from . import util -import random -import time -import sys, os, errno import inspect -from collections import defaultdict, deque import itertools import math -from functools import reduce +import os import re +import sys +from collections import defaultdict, deque +from functools import reduce +import Compiler.instructions +import Compiler.instructions_base +import Compiler.instructions_base as inst_base +from Compiler.config import REG_MAX, USER_MEM, COST +from Compiler.exceptions import CompilerError +from Compiler.instructions_base import RegType + +from . import allocator as al +from . import util data_types = dict( - triple = 0, - square = 1, - bit = 2, - inverse = 3, - dabit = 4, + triple=0, + square=1, + bit=2, + inverse=3, + dabit=4, ) field_types = dict( - modp = 0, - gf2n = 1, - bit = 2, + modp=0, + gf2n=1, + bit=2, ) + class defaults: debug = False verbose = False @@ -62,8 +63,9 @@ class defaults: insecure = False keep_cisc = False + class Program(object): - """ A program consists of a list of tapes representing the whole + """A program consists of a list of tapes representing the whole computation. When compiling an :file:`.mpc` file, the single instances is @@ -71,20 +73,22 @@ class Program(object): from Python code, an instance has to be created before running any instructions. """ - def __init__(self, args, options=defaults): - from .non_linear import Ring, Prime, KnownPrime + + def __init__(self, args, options=defaults, name=None): + from .non_linear import KnownPrime, Prime + self.options = options self.verbose = options.verbose self.args = args + self.name = name self.init_names(args) self._security = 40 self.prime = None self.tapes = [] - if sum(x != 0 for x in(options.ring, options.field, - options.binary)) > 1: - raise CompilerError('can only use one out of -B, -R, -F') + if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1: + raise CompilerError("can only use one out of -B, -R, -F") if options.prime and (options.ring or options.binary): - raise CompilerError('can only use one out of -B, -R, -p') + raise CompilerError("can only use one out of -B, -R, -p") if options.ring: self.set_ring_size(int(options.ring)) else: @@ -93,19 +97,20 @@ def __init__(self, args, options=defaults): self.prime = int(options.prime) max_bit_length = int(options.prime).bit_length() - 2 if self.bit_length > max_bit_length: - raise CompilerError('integer bit length can be maximal %s' % - max_bit_length) + raise CompilerError( + "integer bit length can be maximal %s" % max_bit_length + ) self.bit_length = self.bit_length or max_bit_length self.non_linear = KnownPrime(self.prime) else: self.non_linear = Prime(self.security) if not self.bit_length: self.bit_length = 64 - print('Default bit length:', self.bit_length) - print('Default security parameter:', self.security) + print("Default bit length:", self.bit_length) + print("Default security parameter:", self.security) self.galois_length = int(options.galois) if self.verbose: - print('Galois length:', self.galois_length) + print("Galois length:", self.galois_length) self.tape_counter = 0 self._curr_tape = None self.DEBUG = options.debug @@ -119,24 +124,36 @@ def __init__(self, args, options=defaults): self.public_input_file = None self.types = {} self.budget = int(self.options.budget) - self.to_merge = [Compiler.instructions.asm_open_class, \ - Compiler.instructions.gasm_open_class, \ - Compiler.instructions.muls_class, \ - Compiler.instructions.gmuls_class, \ - Compiler.instructions.mulrs_class, \ - Compiler.instructions.gmulrs, \ - Compiler.instructions.dotprods_class, \ - Compiler.instructions.gdotprods_class, \ - Compiler.instructions.asm_input_class, \ - Compiler.instructions.gasm_input_class, - Compiler.instructions.inputfix_class, - Compiler.instructions.inputfloat_class, - Compiler.instructions.inputmixed_class, - Compiler.instructions.trunc_pr_class, - Compiler.instructions_base.Mergeable] + self.to_merge = [ + Compiler.instructions.asm_open_class, + Compiler.instructions.gasm_open_class, + Compiler.instructions.muls_class, + Compiler.instructions.gmuls_class, + Compiler.instructions.mulrs_class, + Compiler.instructions.gmulrs, + Compiler.instructions.dotprods_class, + Compiler.instructions.gdotprods_class, + Compiler.instructions.asm_input_class, + Compiler.instructions.gasm_input_class, + Compiler.instructions.inputfix_class, + Compiler.instructions.inputfloat_class, + Compiler.instructions.inputmixed_class, + Compiler.instructions.trunc_pr_class, + Compiler.instructions_base.Mergeable, + ] import Compiler.GC.instructions as gc - self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ - gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] + + self.to_merge += [ + gc.ldmsdi, + gc.stmsdi, + gc.ldmsd, + gc.stmsd, + gc.stmsdci, + gc.xors, + gc.andrs, + gc.ands, + gc.inputb, + ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ self.use_dabit = options.mixed @@ -153,7 +170,8 @@ def __init__(self, args, options=defaults): self.n_running_threads = None self.input_files = {} Program.prog = self - from . import instructions_base, instructions, types, comparison + from . import comparison, instructions, instructions_base, types + instructions.program = self instructions_base.program = self types.program = self @@ -164,53 +182,53 @@ def get_args(self): return self.args def max_par_tapes(self): - """ Upper bound on number of tapes that will be run in parallel. - (Excludes empty tapes) """ + """Upper bound on number of tapes that will be run in parallel. + (Excludes empty tapes)""" return self.n_threads - + def init_names(self, args): # ignore path to file - source must be in Programs/Source - if 'Programs' in os.listdir(os.getcwd()): + if "Programs" in os.listdir(os.getcwd()): # compile prog in ./Programs/Source directory - self.programs_dir = os.getcwd() + '/Programs' + self.programs_dir = os.getcwd() + "/Programs" else: # assume source is in main SPDZ directory - self.programs_dir = sys.path[0] + '/Programs' + self.programs_dir = sys.path[0] + "/Programs" if self.verbose: - print('Compiling program in', self.programs_dir) - + print("Compiling program in", self.programs_dir) + # create extra directories if needed - for dirname in ['Public-Input', 'Bytecode', 'Schedules']: - if not os.path.exists(self.programs_dir + '/' + dirname): - os.mkdir(self.programs_dir + '/' + dirname) - - progname = args[0].split('/')[-1] - if progname.endswith('.mpc'): - progname = progname[:-4] - - if os.path.exists(args[0]): - self.infile = args[0] - else: - self.infile = self.programs_dir + '/Source/' + progname + '.mpc' + for dirname in ["Public-Input", "Bytecode", "Schedules"]: + if not os.path.exists(self.programs_dir + "/" + dirname): + os.mkdir(self.programs_dir + "/" + dirname) + + if self.name is None: + self.name = args[0].split("/")[-1] + if self.name.endswith(".mpc"): + self.name = self.name[:-4] + + if os.path.exists(args[0]): + self.infile = args[0] + else: + self.infile = self.programs_dir + "/Source/" + self.name + ".mpc" """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames """ if self.options.outfile: - self.name = self.options.outfile + '-' + progname + self.name = self.options.outfile + "-" + self.name else: - self.name = progname + self.name = self.name if len(args) > 1: - self.name += '-' + '-'.join(re.sub('/', '_', arg) - for arg in args[1:]) - self.progname = progname + self.name += "-" + "-".join(re.sub("/", "_", arg) for arg in args[1:]) def set_ring_size(self, ring_size): from .non_linear import Ring + for tape in self.tapes: - prev = tape.req_bit_length['p'] + prev = tape.req_bit_length["p"] if prev and prev != ring_size: - raise CompilerError('cannot have different ring sizes') + raise CompilerError("cannot have different ring sizes") self.bit_length = ring_size - 1 self.non_linear = Ring(ring_size) self.options.ring = str(ring_size) @@ -234,7 +252,8 @@ def g(): :param function: Python function defining the thread :param args: arguments to the function :param name: name used for files - :param single_thread: Boolean indicating whether tape will never be run in parallel to itself + :param single_thread: Boolean indicating whether tape will + never be run in parallel to itself :returns: tape handle """ @@ -258,20 +277,22 @@ def run_tape(self, tape_index, arg): return self.run_tapes([[tape_index, arg]])[0] def run_tapes(self, args): - """ Run tapes in parallel. See :py:func:`new_tape` for an example. + """Run tapes in parallel. See :py:func:`new_tape` for an example. - :param args: list of tape handles or tuples of tape handle and extra argument (for :py:func:`~Compiler.library.get_arg`) + :param args: list of tape handles or tuples of tape handle and extra + argument (for :py:func:`~Compiler.library.get_arg`) :returns: list of thread numbers """ if not self.curr_tape.singular: - raise CompilerError('Compiler does not support ' \ - 'recursive spawning of threads') + raise CompilerError( + "Compiler does not support " "recursive spawning of threads" + ) args = [list(util.tuplify(arg)) for arg in args] singular_tapes = set() for arg in args: if self.tapes[arg[0]].singular: if arg[0] in singular_tapes: - raise CompilerError('cannot run singular tape in parallel') + raise CompilerError("cannot run singular tape in parallel") singular_tapes.add(arg[0]) assert len(arg) assert len(arg) <= 2 @@ -286,59 +307,59 @@ def run_tapes(self, args): else: thread_numbers.append(self.n_threads) self.n_threads += 1 - self.curr_tape.start_new_basicblock(name='pre-run_tape') - Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in - zip(thread_numbers, args)), [])) - self.curr_tape.start_new_basicblock(name='post-run_tape') + self.curr_tape.start_new_basicblock(name="pre-run_tape") + Compiler.instructions.run_tape( + *sum(([x] + list(y) for x, y in zip(thread_numbers, args)), []) + ) + self.curr_tape.start_new_basicblock(name="post-run_tape") for arg in args: - self.curr_tape.req_node.children.append( - self.tapes[arg[0]].req_tree) + self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree) return thread_numbers def join_tape(self, thread_number): self.join_tapes([thread_number]) def join_tapes(self, thread_numbers): - """ Wait for completion of tapes. See :py:func:`new_tape` for an example. + """Wait for completion of tapes. See :py:func:`new_tape` for an example. :param thread_numbers: list of thread numbers """ - self.curr_tape.start_new_basicblock(name='pre-join_tape') + self.curr_tape.start_new_basicblock(name="pre-join_tape") for thread_number in thread_numbers: Compiler.instructions.join_tape(thread_number) self.curr_tape.free_threads.add(thread_number) - self.curr_tape.start_new_basicblock(name='post-join_tape') + self.curr_tape.start_new_basicblock(name="post-join_tape") def update_req(self, tape): if self.req_num is None: self.req_num = tape.req_num else: self.req_num += tape.req_num - + def write_bytes(self): - """ Write all non-empty threads and schedule to files. """ + """Write all non-empty threads and schedule to files.""" nonempty_tapes = [t for t in self.tapes] - sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name - sch_file = open(sch_filename, 'w') - print('Writing to', sch_filename) - sch_file.write(str(self.max_par_tapes()) + '\n') - sch_file.write(str(len(nonempty_tapes)) + '\n') - sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') - sch_file.write('1 0\n') - sch_file.write('0\n') - sch_file.write(' '.join(sys.argv) + '\n') - req = max(x.req_bit_length['p'] for x in self.tapes) + sch_filename = self.programs_dir + "/Schedules/%s.sch" % self.name + sch_file = open(sch_filename, "w") + print("Writing to", sch_filename) + sch_file.write(str(self.max_par_tapes()) + "\n") + sch_file.write(str(len(nonempty_tapes)) + "\n") + sch_file.write(" ".join(tape.name for tape in nonempty_tapes) + "\n") + sch_file.write("1 0\n") + sch_file.write("0\n") + sch_file.write(" ".join(sys.argv) + "\n") + req = max(x.req_bit_length["p"] for x in self.tapes) if self.options.ring: - sch_file.write('R:%s' % self.options.ring) + sch_file.write("R:%s" % self.options.ring) elif self.options.prime: - sch_file.write('p:%s' % self.options.prime) + sch_file.write("p:%s" % self.options.prime) else: - sch_file.write('lgp:%s' % req) - sch_file.write('\n') - sch_file.write('opts: %s\n' % ' '.join(self.relevant_opts)) + sch_file.write("lgp:%s" % req) + sch_file.write("\n") + sch_file.write("opts: %s\n" % " ".join(self.relevant_opts)) for tape in self.tapes: tape.write_bytes() @@ -347,12 +368,12 @@ def finalize_tape(self, tape): tape.optimize(self.options) tape.write_bytes() if self.options.asmoutfile: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) tape.purge() - + @property def curr_tape(self): - """ The tape that is currently running.""" + """The tape that is currently running.""" if self._curr_tape is None: assert not self.tapes self._curr_tape = Tape(self.name, self) @@ -365,13 +386,13 @@ def curr_tape(self, value): @property def curr_block(self): - """ The basic block that is currently being created. """ + """The basic block that is currently being created.""" return self.curr_tape.active_basicblock - + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): - """ Allocate memory from the top """ + """Allocate memory from the top""" if not isinstance(size, int): - raise CompilerError('size must be known at compile time') + raise CompilerError("size must be known at compile time") if size == 0: return if isinstance(mem_type, type): @@ -389,8 +410,7 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): single_size = size size *= self.n_running_threads else: - raise CompilerError('cannot allocate memory ' - 'outside main thread') + raise CompilerError("cannot allocate memory " "outside main thread") blocks = self.free_mem_blocks[mem_type] addr = blocks.pop(size) if addr is not None: @@ -400,24 +420,23 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) - if addr + size >= 2 ** 32: - raise CompilerError("allocation exceeded for type '%s'" % - mem_type) - self.allocated_mem_blocks[addr,mem_type] = size + if addr + size >= 2**32: + raise CompilerError("allocation exceeded for type '%s'" % mem_type) + self.allocated_mem_blocks[addr, mem_type] = size if single_size: from .library import get_thread_number, runtime_error_if + tn = get_thread_number() - runtime_error_if(tn > self.n_running_threads, 'malloc') + runtime_error_if(tn > self.n_running_threads, "malloc") return addr + single_size * (tn - 1) else: return addr def free(self, addr, mem_type): - """ Free memory """ - if self.curr_block.alloc_pool \ - is not self.curr_tape.basicblocks[0].alloc_pool: - raise CompilerError('Cannot free memory within function block') - size = self.allocated_mem_blocks.pop((addr,mem_type)) + """Free memory""" + if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool: + raise CompilerError("Cannot free memory within function block") + size = self.allocated_mem_blocks.pop((addr, mem_type)) self.free_mem_blocks[mem_type].push(addr, size) def finalize(self): @@ -435,47 +454,48 @@ def finalize(self): if self.options.asmoutfile: for tape in self.tapes: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) def finalize_memory(self): - from . import library - self.curr_tape.start_new_basicblock(None, 'memory-usage') + self.curr_tape.start_new_basicblock(None, "memory-usage") # reset register counter to 0 if not self.options.noreallocate: self.curr_tape.init_registers() - for mem_type,size in sorted(self.allocated_mem.items()): + for mem_type, size in sorted(self.allocated_mem.items()): if size: - #print "Memory of type '%s' of size %d" % (mem_type, size) + # print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: self.types[mem_type].load_mem(size - 1, mem_type) else: from Compiler.types import _get_type + _get_type(mem_type).load_mem(size - 1, mem_type) if self.verbose: if self.saved: - print('Saved %s memory units through reallocation' % self.saved) + print("Saved %s memory units through reallocation" % self.saved) def public_input(self, x): - """ Append a value to the public input file. """ + """Append a value to the public input file.""" if self.public_input_file is None: - self.public_input_file = open(self.programs_dir + - '/Public-Input/%s' % self.name, 'w') - self.public_input_file.write('%s\n' % str(x)) + self.public_input_file = open( + self.programs_dir + "/Public-Input/%s" % self.name, "w" + ) + self.public_input_file.write("%s\n" % str(x)) def set_bit_length(self, bit_length): - """ Change the integer bit length for non-linear functions. """ + """Change the integer bit length for non-linear functions.""" self.bit_length = bit_length - print('Changed bit length for comparisons etc. to', bit_length) + print("Changed bit length for comparisons etc. to", bit_length) def set_security(self, security): self._security = security self.non_linear.set_security(security) - print('Changed statistical security for comparison etc. to', security) + print("Changed statistical security for comparison etc. to", security) @property def security(self): - """ The statistical security parameter for non-linear - functions. """ + """The statistical security parameter for non-linear + functions.""" return self._security @security.setter @@ -493,7 +513,7 @@ def get_tape_counter(self): @property def use_trunc_pr(self): if not self._use_trunc_pr: - self.relevant_opts.add('trunc_pr') + self.relevant_opts.add("trunc_pr") return self._use_trunc_pr @use_trunc_pr.setter @@ -501,7 +521,7 @@ def use_trunc_pr(self, change): self._use_trunc_pr = change def use_edabit(self, change=None): - """ Setting whether to use edaBits for non-linear + """Setting whether to use edaBits for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -509,7 +529,7 @@ def use_edabit(self, change=None): """ if change is None: if not self._edabit: - self.relevant_opts.add('edabit') + self.relevant_opts.add("edabit") return self._edabit else: self._edabit = change @@ -518,7 +538,7 @@ def use_edabit_for(self, *args): return True def use_split(self, change=None): - """ Setting whether to use local arithmetic-binary share + """Setting whether to use local arithmetic-binary share conversion for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -526,16 +546,16 @@ def use_split(self, change=None): """ if change is None: if not self._split: - self.relevant_opts.add('split') + self.relevant_opts.add("split") return self._split else: if change and not self.options.ring: - raise CompilerError('splitting only supported for rings') - assert change > 1 or change == False + raise CompilerError("splitting only supported for rings") + assert change > 1 or change is False self._split = change def use_square(self, change=None): - """ Setting whether to use preprocessed square tuples + """Setting whether to use preprocessed square tuples (default: false). :param change: change setting if not :py:obj:`None` @@ -559,22 +579,22 @@ def linear_rounds(self, change=None): self._linear_rounds = change def options_from_args(self): - """ Set a number of options from the command-line arguments. """ - if 'trunc_pr' in self.args: + """Set a number of options from the command-line arguments.""" + if "trunc_pr" in self.args: self.use_trunc_pr = True - if 'signed_trunc_pr' in self.args: + if "signed_trunc_pr" in self.args: self.use_trunc_pr = -1 - if 'split' in self.args or 'split3' in self.args: + if "split" in self.args or "split3" in self.args: self.use_split(3) for arg in self.args: - m = re.match('split([0-9]+)', arg) + m = re.match("split([0-9]+)", arg) if m: self.use_split(int(m.group(1))) - if 'raw' in self.args: + if "raw" in self.args: self.always_raw(True) - if 'edabit' in self.args: + if "edabit" in self.args: self.use_edabit(True) - if 'linear_rounds' in self.args: + if "linear_rounds" in self.args: self.linear_rounds(True) def disable_memory_warnings(self): @@ -583,28 +603,32 @@ def disable_memory_warnings(self): @staticmethod def read_tapes(schedule): - m = re.search(r'([^/]*)\.mpc', schedule) + m = re.search(r"([^/]*)\.mpc", schedule) if m: schedule = m.group(1) if not os.path.exists(schedule): - schedule = 'Programs/Schedules/%s.sch' % schedule + schedule = "Programs/Schedules/%s.sch" % schedule try: lines = open(schedule).readlines() except FileNotFoundError: - print('%s not found, have you compiled the program?' % schedule, - file=sys.stderr) + print( + "%s not found, have you compiled the program?" % schedule, + file=sys.stderr, + ) sys.exit(1) - for tapename in lines[2].split(' '): + for tapename in lines[2].split(" "): yield tapename.strip() + class Tape: - """ A tape contains a list of basic blocks, onto which instructions are added. """ + """A tape contains a list of basic blocks, onto which instructions are added.""" + def __init__(self, name, program): - """ Set prime p and the initial instructions and registers. """ + """Set prime p and the initial instructions and registers.""" self.program = program - name += '-%d' % program.get_tape_counter() + name += "-%d" % program.get_tape_counter() self.init_names(name) self.init_registers() self.req_tree = self.ReqNode(name) @@ -658,9 +682,9 @@ def set_return(self, previous_block, sub_block): def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset - + def set_exit(self, condition, exit_true=None): - """ Sets the block which we start from next, depending on the condition. + """Sets the block which we start from next, depending on the condition. (Default is to go to next block in the list) """ @@ -668,34 +692,33 @@ def set_exit(self, condition, exit_true=None): self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False - + def add_jump(self): - """ Add the jump for this block's exit condition to list of - instructions (must be done after merging) """ + """Add the jump for this block's exit condition to list of + instructions (must be done after merging)""" self.instructions.append(self.exit_condition) - + def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) - + def adjust_jump(self): - """ Set the correct relative jump offset """ + """Set the correct relative jump offset""" offset = self.get_offset(self.exit_block) self.exit_condition.set_relative_jump(offset) - #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self, retain_usage=True): def relevant(inst): - req_node = Tape.ReqNode('') + req_node = Tape.ReqNode("") req_node.num = Tape.ReqNum() inst.add_usage(req_node) return req_node.num != {} + if retain_usage: - self.usage_instructions = list(filter(relevant, - self.instructions)) + self.usage_instructions = list(filter(relevant, self.instructions)) else: self.usage_instructions = [] if len(self.usage_instructions) > 1000: - print('Retaining %d instructions' % len(self.usage_instructions)) + print("Retaining %d instructions" % len(self.usage_instructions)) del self.instructions self.purged = True @@ -706,14 +729,14 @@ def add_usage(self, req_node): instructions = self.instructions for inst in instructions: inst.add_usage(req_node) - req_node.num['all', 'round'] += self.n_rounds - req_node.num['all', 'inv'] += self.n_to_merge + req_node.num["all", "round"] += self.n_rounds + req_node.num["all", "inv"] += self.n_to_merge def expand_cisc(self): new_instructions = [] - if self.parent.program.options.keep_cisc != None: - skip = ['LTZ', 'Trunc'] - skip += self.parent.program.options.keep_cisc.split(',') + if self.parent.program.options.keep_cisc is not None: + skip = ["LTZ", "Trunc"] + skip += self.parent.program.options.keep_cisc.split(",") else: skip = [] for inst in self.instructions: @@ -726,38 +749,38 @@ def __str__(self): return self.name def is_empty(self): - """ Returns True if the list of basic blocks is empty. + """Returns True if the list of basic blocks is empty. Note: False is returned even when tape only contains basic blocks with no instructions. However, these are removed when - optimize is called. """ + optimize is called.""" if not self.purged: - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 return self._is_empty - def start_new_basicblock(self, scope=False, name=''): + def start_new_basicblock(self, scope=False, name=""): # use False because None means no scope if scope is False: scope = self.active_basicblock - suffix = '%s-%d' % (name, self.block_counter) + suffix = "%s-%d" % (name, self.block_counter) self.block_counter += 1 - sub = self.BasicBlock(self, self.name + '-' + suffix, scope) + sub = self.BasicBlock(self, self.name + "-" + suffix, scope) self.basicblocks.append(sub) self.active_basicblock = sub self.req_node.add_block(sub) - #print 'Compiling basic block', sub.name + # print 'Compiling basic block', sub.name def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) - + def init_names(self, name): self.name = name - self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' + self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" def purge(self): for block in self.basicblocks: block.purge() - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 del self.basicblocks del self.active_basicblock self.purged = True @@ -767,26 +790,29 @@ def wrapper(self, *args, **kwargs): if self.purged: return return function(self, *args, **kwargs) + return wrapper @unpurged def optimize(self, options): if len(self.basicblocks) == 0: - print('Tape %s is empty' % self.name) + print("Tape %s is empty" % self.name) return if self.if_states: - print('Tracebacks for open blocks:') + print("Tracebacks for open blocks:") for state in self.if_states: try: print(util.format_trace(state.caller)) except AttributeError: pass print() - raise CompilerError('Unclosed if/else blocks, see tracebacks above') + raise CompilerError("Unclosed if/else blocks, see tracebacks above") if self.program.verbose: - print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) + print( + "Processing tape", self.name, "with %d blocks" % len(self.basicblocks) + ) for block in self.basicblocks: al.determine_scope(block, options) @@ -794,41 +820,56 @@ def optimize(self, options): # merge open instructions # need to do this if there are several blocks if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: - for i,block in enumerate(self.basicblocks): + for i, block in enumerate(self.basicblocks): if len(block.instructions) > 0 and self.program.verbose: - print('Processing basic block %s, %d/%d, %d instructions' % \ - (block.name, i, len(self.basicblocks), \ - len(block.instructions))) + print( + "Processing basic block %s, %d/%d, %d instructions" + % ( + block.name, + i, + len(self.basicblocks), + len(block.instructions), + ) + ) # the next call is necessary for allocation later even without merging - merger = al.Merger(block, options, \ - tuple(self.program.to_merge)) + merger = al.Merger(block, options, tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 1000000: - print('Eliminate dead code...') + print("Eliminate dead code...") merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: block.used_from_scope = util.set_by_id() continue if len(block.instructions) > 1000000: - print('Merging instructions...') + print("Merging instructions...") numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) if merger.counter and self.program.verbose: - print('Block requires', \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.counter.items()))) + print( + "Block requires", + ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.counter.items()) + ), + ) if merger.counter and self.program.verbose: - print('Block requires %s rounds' % \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.rounds.items()))) + print( + "Block requires %s rounds" + % ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.rounds.items()) + ) + ) # free memory merger = None if options.dead_code_elimination: - block.instructions = [x for x in block.instructions if x is not None] + block.instructions = [ + x for x in block.instructions if x is not None + ] if not (options.merge_opens and self.merge_opens): - print('Not merging instructions in tape %s' % self.name) + print("Not merging instructions in tape %s" % self.name) if options.cisc: self.expand_cisc() @@ -853,19 +894,27 @@ def optimize(self, options): reg_counts = self.count_regs() if options.noreallocate: if self.program.verbose: - print('Tape register usage:', dict(reg_counts)) + print("Tape register usage:", dict(reg_counts)) else: if self.program.verbose: - print('Tape register usage before re-allocation:', - dict(reg_counts)) - print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) - print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) - print('Re-allocating...') + print("Tape register usage before re-allocation:", dict(reg_counts)) + print( + "modp: %d clear, %d secret" + % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) + ) + print( + "GF2N: %d clear, %d secret" + % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) + ) + print("Re-allocating...") allocator = al.StraightlineAllocator(REG_MAX, self.program) + def alloc(block): - for reg in sorted(block.used_from_scope, - key=lambda x: (x.reg_type, x.i)): + for reg in sorted( + block.used_from_scope, key=lambda x: (x.reg_type, x.i) + ): allocator.alloc_reg(reg, block.alloc_pool) + def alloc_loop(block): left = deque([block]) while left: @@ -873,73 +922,84 @@ def alloc_loop(block): alloc(block) for child in block.children: left.append(child) - for i,block in enumerate(reversed(self.basicblocks)): + + for i, block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 1000000: - print('Allocating %s, %d/%d' % \ - (block.name, i, len(self.basicblocks))) + print( + "Allocating %s, %d/%d" % (block.name, i, len(self.basicblocks)) + ) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() - if isinstance(jump, int) and jump < 0 and \ - block.exit_block.scope is not None: + if ( + isinstance(jump, int) + and jump < 0 + and block.exit_block.scope is not None + ): alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) allocator.finalize(options) if self.program.verbose: - print('Tape register usage:', dict(allocator.usage)) + print("Tape register usage:", dict(allocator.usage)) # offline data requirements if self.program.verbose: - print('Compile offline data requirements...') + print("Compile offline data requirements...") self.req_num = self.req_tree.aggregate() if self.program.verbose: - print('Tape requires', self.req_num) - for req,num in sorted(self.req_num.items()): - if num == float('inf') or num >= 2 ** 32: + print("Tape requires", self.req_num) + for req, num in sorted(self.req_num.items()): + if num == float("inf") or num >= 2**32: num = -1 if req[1] in data_types: self.basicblocks[-1].instructions.append( - Compiler.instructions.use(field_types[req[0]], \ - data_types[req[1]], num, \ - add_to_prog=False)) - elif req[1] == 'input': + Compiler.instructions.use( + field_types[req[0]], data_types[req[1]], num, add_to_prog=False + ) + ) + elif req[1] == "input": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_inp(field_types[req[0]], \ - req[2], num, \ - add_to_prog=False)) - elif req[0] == 'modp': + Compiler.instructions.use_inp( + field_types[req[0]], req[2], num, add_to_prog=False + ) + ) + elif req[0] == "modp": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'gf2n': + Compiler.instructions.use_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "gf2n": self.basicblocks[-1].instructions.append( - Compiler.instructions.guse_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'edabit': + Compiler.instructions.guse_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "edabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(False, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'sedabit': + Compiler.instructions.use_edabit( + False, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "sedabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(True, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'matmul': + Compiler.instructions.use_edabit( + True, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "matmul": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_matmul(*req[1], num, \ - add_to_prog=False)) + Compiler.instructions.use_matmul(*req[1], num, add_to_prog=False) + ) if not self.is_empty(): # bit length requirement - for x in ('p', '2'): + for x in ("p", "2"): if self.req_bit_length[x]: bl = self.req_bit_length[x] if self.program.options.ring: bl = -int(self.program.options.ring) self.basicblocks[-1].instructions.append( - Compiler.instructions.reqbl(bl, - add_to_prog=False)) + Compiler.instructions.reqbl(bl, add_to_prog=False) + ) if self.program.verbose: - print('Tape requires prime bit length', self.req_bit_length['p']) - print('Tape requires galois bit length', self.req_bit_length['2']) + print("Tape requires prime bit length", self.req_bit_length["p"]) + print("Tape requires galois bit length", self.req_bit_length["2"]) @unpurged def expand_cisc(self): @@ -948,93 +1008,99 @@ def expand_cisc(self): @unpurged def _get_instructions(self): - return itertools.chain.\ - from_iterable(b.instructions for b in self.basicblocks) + return itertools.chain.from_iterable(b.instructions for b in self.basicblocks) @unpurged def get_encoding(self): - """ Get the encoding of the program, in human-readable format. """ + """Get the encoding of the program, in human-readable format.""" return [i.get_encoding() for i in self._get_instructions() if i is not None] - + @unpurged def get_bytes(self): - """ Get the byte encoding of the program as an actual string of bytes. """ - return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) - + """Get the byte encoding of the program as an actual string of bytes.""" + return b"".join( + i.get_bytes() for i in self._get_instructions() if i is not None + ) + @unpurged def write_encoding(self, filename): - """ Write the readable encoding to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the readable encoding to a file.""" + print("Writing to", filename) + f = open(filename, "w") for line in self.get_encoding(): - f.write(str(line) + '\n') + f.write(str(line) + "\n") f.close() - + @unpurged def write_str(self, filename): - """ Write the sequence of instructions to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the sequence of instructions to a file.""" + print("Writing to", filename) + f = open(filename, "w") n = 0 for block in self.basicblocks: if block.instructions: - f.write('# %s\n' % block.name) + f.write("# %s\n" % block.name) for line in block.instructions: - f.write('%s # %d\n' % (line, n)) + f.write("%s # %d\n" % (line, n)) n += 1 f.close() - + @unpurged def write_bytes(self, filename=None): - """ Write the program's byte encoding to a file. """ + """Write the program's byte encoding to a file.""" if filename is None: filename = self.outfile - if not filename.endswith('.bc'): - filename += '.bc' - if not 'Bytecode' in filename: - filename = self.program.programs_dir + '/Bytecode/' + filename - print('Writing to', filename) - f = open(filename, 'wb') + if not filename.endswith(".bc"): + filename += ".bc" + if "Bytecode" not in filename: + filename = self.program.programs_dir + "/Bytecode/" + filename + print("Writing to", filename) + f = open(filename, "wb") for i in self._get_instructions(): if i is not None: f.write(i.get_bytes()) f.close() - + def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) - + def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] - + def __str__(self): return self.name class ReqNum(defaultdict): def __init__(self, init={}): super(Tape.ReqNum, self).__init__(lambda: 0, init) + def __add__(self, other): res = Tape.ReqNum() - for i,count in list(self.items()): - res[i] += count - for i,count in list(other.items()): + for i, count in list(self.items()): + res[i] += count + for i, count in list(other.items()): res[i] += count return res + def __mul__(self, other): res = Tape.ReqNum() for i in self: res[i] = other * self[i] return res + __rmul__ = __mul__ + def set_all(self, value): - if value == float('inf') and self['all', 'inv'] > 0: - print('Going to unknown from %s' % self) + if value == float("inf") and self["all", "inv"] > 0: + print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: res[i] = value return res + def max(self, other): res = Tape.ReqNum() for i in self: @@ -1042,82 +1108,103 @@ def max(self, other): for i in other: res[i] = max(self[i], other[i]) return res + def cost(self): - return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ - if req[1] != 'input' and req[0] != 'edabit') + return sum( + num * COST[req[0]][req[1]] + for req, num in list(self.items()) + if req[1] != "input" and req[0] != "edabit" + ) + def pretty(self): - t = lambda x: 'integer' if x == 'modp' else x + def t(x): + return "integer" if x == "modp" else x + res = [] for req, num in self.items(): domain = t(req[0]) - n = '%12.0f' % num - if req[1] == 'input': - res += ['%s %s inputs from player %d' \ - % (n, domain, req[2])] - elif domain.endswith('edabit'): - if domain == 'sedabit': - eda = 'strict edabits' + n = "%12.0f" % num + if req[1] == "input": + res += ["%s %s inputs from player %d" % (n, domain, req[2])] + elif domain.endswith("edabit"): + if domain == "sedabit": + eda = "strict edabits" else: - eda = 'loose edabits' - res += ['%s %s of length %d' % (n, eda, req[1])] - elif domain == 'matmul': - res += ['%s matrix multiplications (%dx%d * %dx%d)' % - (n, req[1][0], req[1][1], req[1][1], req[1][2])] - elif req[0] != 'all': - res += ['%s %s %ss' % (n, domain, req[1])] - if self['all','round']: - res += ['% 12.0f virtual machine rounds' % self['all','round']] + eda = "loose edabits" + res += ["%s %s of length %d" % (n, eda, req[1])] + elif domain == "matmul": + res += [ + "%s matrix multiplications (%dx%d * %dx%d)" + % (n, req[1][0], req[1][1], req[1][1], req[1][2]) + ] + elif req[0] != "all": + res += ["%s %s %ss" % (n, domain, req[1])] + if self["all", "round"]: + res += ["% 12.0f virtual machine rounds" % self["all", "round"]] return res + def __str__(self): - return ', '.join(self.pretty()) + return ", ".join(self.pretty()) + def __repr__(self): return repr(dict(self)) class ReqNode(object): - __slots__ = ['num', 'children', 'name', 'blocks'] + __slots__ = ["num", "children", "name", "blocks"] + def __init__(self, name): self.children = [] self.name = name self.blocks = [] + def aggregate(self, *args): self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) - res = reduce(lambda x,y: x + y.aggregate(self.name), - self.children, self.num) + res = reduce( + lambda x, y: x + y.aggregate(self.name), self.children, self.num + ) return res + def increment(self, data_type, num=1): self.num[data_type] += num + def add_block(self, block): self.blocks.append(block) class ReqChild(object): - __slots__ = ['aggregator', 'nodes', 'parent'] + __slots__ = ["aggregator", "nodes", "parent"] + def __init__(self, aggregator, parent): self.aggregator = aggregator self.nodes = [] self.parent = parent + def aggregate(self, name): res = self.aggregator([node.aggregate() for node in self.nodes]) try: n_reps = self.aggregator([1]) - n_rounds = res['all', 'round'] - n_invs = res['all', 'inv'] + n_rounds = res["all", "round"] + n_invs = res["all", "inv"] if (n_invs / n_rounds) * 1000 < n_reps: - print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ - '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)) - except: + print( + self.nodes[0].blocks[0].name, + "blowing up rounds: ", + "(%d / %d) ** 3 < %d" % (n_rounds, n_reps, n_invs), + ) + except Exception: pass return res + def add_node(self, tape, name): new_node = Tape.ReqNode(name) self.nodes.append(new_node) tape.req_node = new_node - def open_scope(self, aggregator, scope=False, name=''): + def open_scope(self, aggregator, scope=False, name=""): child = self.ReqChild(aggregator, self.req_node) self.req_node.children.append(child) - child.add_node(self, '%s-%d' % (name, len(self.basicblocks))) + child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) self.start_new_basicblock(name=name) return child @@ -1125,21 +1212,21 @@ def close_scope(self, outer_scope, parent_req_node, name): self.req_node = parent_req_node self.start_new_basicblock(outer_scope, name) - def require_bit_length(self, bit_length, t='p'): - if t == 'p': + def require_bit_length(self, bit_length, t="p"): + if t == "p": if self.program.prime: - if (bit_length >= self.program.prime.bit_length() - 1): + if bit_length >= self.program.prime.bit_length() - 1: raise CompilerError( - 'required bit length %d too much for %d' % \ - (bit_length, self.program.prime)) - self.req_bit_length[t] = max(bit_length + 1, \ - self.req_bit_length[t]) + "required bit length %d too much for %d" + % (bit_length, self.program.prime) + ) + self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t]) else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) @staticmethod def read_instructions(tapename): - tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb') + tape = open("Programs/Bytecode/%s.bc" % tapename, "rb") while tape.peek(): yield inst_base.ParsedInstruction(tape) @@ -1147,23 +1234,35 @@ class _no_truth(object): __slots__ = [] def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") + raise CompilerError( + "Cannot derive truth value from register, " + "consider using 'compile.py -l'" + ) class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. """ - __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ - "size", "vector", "vectorbase", "caller", \ - "can_eliminate", "duplicates"] + + __slots__ = [ + "reg_type", + "program", + "absolute_i", + "relative_i", + "size", + "vector", + "vectorbase", + "caller", + "can_eliminate", + "duplicates", + ] maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): - """ Creates a new register. - reg_type must be one of those defined in RegType. """ - if Compiler.instructions_base.get_global_instruction_type() == 'gf2n': + """Creates a new register. + reg_type must be one of those defined in RegType.""" + if Compiler.instructions_base.get_global_instruction_type() == "gf2n": if reg_type == RegType.ClearModp: reg_type = RegType.ClearGF2N elif reg_type == RegType.SecretModp: @@ -1173,7 +1272,7 @@ def __init__(self, reg_type, program, size=None, i=None): if size is None: size = Compiler.instructions_base.get_global_vector_size() if size is not None and size > self.maximum_size: - raise CompilerError('vector too large: %d' % size) + raise CompilerError("vector too large: %d" % size) self.size = size self.vectorbase = self self.relative_i = 0 @@ -1183,7 +1282,7 @@ def __init__(self, reg_type, program, size=None, i=None): self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size else: - self.i = float('inf') + self.i = float("inf") self.vector = [] self.can_eliminate = True self.duplicates = util.set_by_id([self]) @@ -1204,13 +1303,14 @@ def set_size(self, size): if self.size == size: return else: - raise CompilerError('Mismatch of instruction and register size:' - ' %s != %s' % (self.size, size)) + raise CompilerError( + "Mismatch of instruction and register size:" + " %s != %s" % (self.size, size) + ) def set_vectorbase(self, vectorbase): if self.vectorbase is not self: - raise CompilerError('Cannot assign one register' \ - 'to several vectors') + raise CompilerError("Cannot assign one register" "to several vectors") self.relative_i = self.i - vectorbase.i self.vectorbase = vectorbase @@ -1218,7 +1318,7 @@ def _new_by_number(self, i, size=1): return Tape.Register(self.reg_type, self.program, size=size, i=i) def get_vector(self, base=0, size=None): - if size == None: + if size is None: size = self.size if base == 0 and size == self.size: return self @@ -1227,7 +1327,7 @@ def get_vector(self, base=0, size=None): res = self._new_by_number(self.i + base, size=size) res.set_vectorbase(self) self.create_vector_elements() - res.vector = self.vector[base:base+size] + res.vector = self.vector[base : base + size] return res def create_vector_elements(self): @@ -1265,14 +1365,18 @@ def link(self, other): @property def is_gf2n(self): - return self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.SecretGF2N - + return ( + self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.SecretGF2N + ) + @property def is_clear(self): - return self.reg_type == RegType.ClearModp or \ - self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.ClearInt + return ( + self.reg_type == RegType.ClearModp + or self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.ClearInt + ) def __str__(self): return self.reg_type + str(self.i) From 4859a09633f4696040f8a6800a7bd35ec69b9622 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 14:53:32 -0400 Subject: [PATCH 099/221] update to use decorator --- Compiler/compilerLib.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index bd9368ba3..d695c5066 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -5,9 +5,10 @@ import tempfile from optparse import OptionParser +from Compiler.exceptions import CompilerError + from .GC import types as GC_types from .program import Program, defaults -from Compiler.exceptions import CompilerError class Compiler: @@ -339,10 +340,34 @@ def compile_file(self): return self.finalize_compile() - def compile_func(self, f, name): - self.prep_compile(name) - print(f"Compiling function: {f.__name__}") - f(self.VARS) + def register_function(self, name=None): + """ + To register a function to be compiled, use this as a decorator. + Example: + + @compiler.register_function('test-mpc') + def test_mpc(compiler): + ... + """ + + def inner(func): + self.compile_name = name or func.__name__ + self.compile_function = func + return func + + return inner + + def compile_func(self): + if not hasattr(self, "compile_name") and hasattr(self, "compile_func"): + raise CompilerError( + "No function to compile. " + "Did you decorate a function with @register_fuction(name)?" + ) + self.prep_compile(self.compile_name) + print( + f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" + ) + self.compile_function(self) self.finalize_compile() def finalize_compile(self): From 24a7b4f69d0618bebbe29cc31533c1ca9f829061 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 15:12:40 -0400 Subject: [PATCH 100/221] add setup.py to and an example mpc program --- Programs/Source/test_args.mpc | 31 +++++++++++++++++++++++++++++++ setup.py | 7 +++++++ 2 files changed, 38 insertions(+) create mode 100644 Programs/Source/test_args.mpc create mode 100644 setup.py diff --git a/Programs/Source/test_args.mpc b/Programs/Source/test_args.mpc new file mode 100644 index 000000000..88a3a8035 --- /dev/null +++ b/Programs/Source/test_args.mpc @@ -0,0 +1,31 @@ +from Compiler.library import print_ln +from Compiler.types import Matrix, sint +from Compiler.compilerLib import Compiler + + +usage = "usage: %prog [options] [args]" +compiler = Compiler(usage=usage) +compiler.parser.add_option("--rows", dest="rows") +compiler.parser.add_option("--columns", dest="columns") +compiler.parse_args() +if not compiler.options.rows: + compiler.parser.error("--rows required") +if not compiler.options.columns: + compiler.parser.error("--columns required") + + +@compiler.register_function('testmpc') +def main(compiler): + numrows = int(compiler.options.rows) + numcolumns = int(compiler.options.columns) + rows = range(numrows) + reports = Matrix(numrows, numcolumns, sint) + reports.assign_vector( + sint.get_input_from(0, size=numrows * numcolumns) + ) + for row in rows: + print_ln(f"report[{row}]: %s", reports[row].reveal()) + + +if __name__ == "__main__": + compiler.compile_func() diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..5e850bc5e --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name='mp-spdz-compiler', + version='0.1.0', + packages=find_packages(include=['Compiler', 'Compiler.*']) +) From 7005ba4eaec426714a43892a95a0fea9bee90549 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 15:40:14 -0400 Subject: [PATCH 101/221] remove unneed compiler parameter --- Compiler/compilerLib.py | 2 +- Programs/Source/test_args.mpc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index d695c5066..113f4a8e9 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -367,7 +367,7 @@ def compile_func(self): print( f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" ) - self.compile_function(self) + self.compile_function() self.finalize_compile() def finalize_compile(self): diff --git a/Programs/Source/test_args.mpc b/Programs/Source/test_args.mpc index 88a3a8035..a9cee12e8 100644 --- a/Programs/Source/test_args.mpc +++ b/Programs/Source/test_args.mpc @@ -15,7 +15,7 @@ if not compiler.options.columns: @compiler.register_function('testmpc') -def main(compiler): +def main(): numrows = int(compiler.options.rows) numcolumns = int(compiler.options.columns) rows = range(numrows) From 497dd79ab4e90aa4cdc24e98d292756c882b9b92 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 3 Aug 2022 18:48:51 +1000 Subject: [PATCH 102/221] Fix bug in LSB extraction. --- Compiler/comparison.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 23bee2190..84bdd22b6 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -637,6 +637,7 @@ def Mod2(a_0, a, k, kappa, signed): t = [program.curr_block.new_reg('s') for i in range(6)] c2k1 = program.curr_block.new_reg('c') PRandM(r_dprime, r_prime, [r_0], k, 1, kappa) + r_0 = r_prime mulsi(t[0], r_dprime, 2) if signed: ld2i(c2k1, k - 1) From c4c167fac7e772090941a0445902eea02848f2bd Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 5 Aug 2022 15:09:03 +1000 Subject: [PATCH 103/221] Flow optimization test. --- Programs/Source/test_flow_optimization.mpc | 23 ++++++++++++++++++++++ Scripts/test_flow_optimization.sh | 4 ++++ 2 files changed, 27 insertions(+) create mode 100644 Programs/Source/test_flow_optimization.mpc create mode 100755 Scripts/test_flow_optimization.sh diff --git a/Programs/Source/test_flow_optimization.mpc b/Programs/Source/test_flow_optimization.mpc new file mode 100644 index 000000000..ba7af6507 --- /dev/null +++ b/Programs/Source/test_flow_optimization.mpc @@ -0,0 +1,23 @@ +n = 10 ** 7 +a = regint.Array(n) +b = regint.Array(n) + +for i in range(n): + if i > 1000: + a[i] = i + + if i < 1000: + b[i] = -1 + else: + b[i] = 2 * i + +def test(a, index, value): + print_ln('expected %s got %s at %s', value, a[index], index) + crash(a[index] != value) + +test(a, 999, 0) +test(b, 999, -1) +test(a, 10000, 10000) +test(b, 10000, 20000) +test(a, 1000000, 1000000) +test(b, 1000000, 2000000) diff --git a/Scripts/test_flow_optimization.sh b/Scripts/test_flow_optimization.sh new file mode 100755 index 000000000..b9ec62f6a --- /dev/null +++ b/Scripts/test_flow_optimization.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +./compile.py -l test_flow_optimization || exit 1 +Scripts/rep-field.sh test_flow_optimization || exit 1 From 5e4e3dd1a981e202b4eb99aa2feadd81d3f6fadb Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 5 Aug 2022 10:22:01 -0400 Subject: [PATCH 104/221] load mpc file as a string, not bytes --- Compiler/compilerLib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 113f4a8e9..9e36e9f9e 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -279,7 +279,7 @@ def compile_file(self): parallelisable open instructions.""" print("Compiling file", self.prog.infile) - with open(self.prog.infile, "rb") as f: + with open(self.prog.infile, "r") as f: changed = False if self.options.flow_optimization: output = [] From a1658819cd6d3faa4ea98dcbef255bb0cf01ab47 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 6 Aug 2022 12:47:36 +1000 Subject: [PATCH 105/221] Fix bug in sintbit. --- Compiler/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 1531c49d8..128fa2039 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2821,7 +2821,9 @@ def __xor__(self, other): elif util.is_zero(other): return self elif util.is_one(other): - return 1 + res = sintbit() + submr(res, cint(1), self) + return res else: return NotImplemented From 80c7250ec8c64cbd7ec8d463fab774624ac7c519 Mon Sep 17 00:00:00 2001 From: hernan232 Date: Mon, 8 Aug 2022 11:18:23 -0500 Subject: [PATCH 106/221] Add documentation about SPDZ2k non-interactive execution and correct typos in README. --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 18278c25f..8eefe82da 100644 --- a/README.md +++ b/README.md @@ -630,6 +630,12 @@ e.g. if this machine is name `diffie` on the local network: The software uses TCP ports around 5000 by default, use the `-pn` argument to change that. +If you are using the SPDZ2k protocol in non-interactive mode to run a +program compiled with a ring size different from 64, you must specify +the ring size in the script to run the program as follows: + +`Scripts/spdz2k.sh -R tutorial` + ### Yao's garbled circuits We use half-gate garbling as described by [Zahur et @@ -796,7 +802,7 @@ with three parties overall, Party 0 and 1 run the online phase. ## BMR -BMR (Bellare-Micali-Rogaway) is a method of generating a garbled circuit +BMR (Beaver-Micali-Rogaway) is a method of generating a garbled circuit using another secure computation protocol. We have implemented BMR based on all available implementations using GF(2^128) because the nature of this field particularly suits the Free-XOR optimization for garbled From d6f843f5cf480281681f7cedfd90b7f75352056a Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 8 Aug 2022 18:14:55 +1000 Subject: [PATCH 107/221] Fix bugs in Python client. --- ExternalIO/client.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ExternalIO/client.py b/ExternalIO/client.py index c0033275a..a6fd0b035 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -73,20 +73,21 @@ def reset_write_head(self): self.ptr = 0 def Send(self, socket): - socket.send(struct.pack(' Date: Wed, 10 Aug 2022 12:45:19 +1000 Subject: [PATCH 108/221] Use edaBits for equality test with rings. --- Compiler/floatingpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index c596240b6..d3d3f8c50 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,7 +28,9 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_dabit: + if program.Program.prog.use_edabit: + r_prime, r = types.sint.get_edabit(k) + elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) r_prime = types.sint.bit_compose(rr) else: From 3f90cc3e7c7a573447687066003155d084b71bce Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 11 Aug 2022 11:11:00 +1000 Subject: [PATCH 109/221] Fix bugs in sorting with binary circuits. --- Compiler/GC/types.py | 9 +++++++++ Compiler/types.py | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 9fdb59043..5530432bb 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -236,6 +236,11 @@ def if_else(self, x, y): This will output 1. """ return result_conv(x, y)(self & (x ^ y) ^ y) + def zero_if_not(self, condition): + if util.is_constant(condition): + return self * condition + else: + return self * cbit.conv(condition) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -491,6 +496,8 @@ def __mul__(self, other): if isinstance(other, int): return self.mul_int(other) try: + if (self.n, other.n) == (1, 1): + return self & other if min(self.n, other.n) != 1: raise NotImplementedError('high order multiplication') n = max(self.n, other.n) @@ -740,6 +747,8 @@ def bit_compose(cls, bits): bits += [0] * (n - len(bits)) assert len(bits) == n return cls.from_vec(bits) + def zero_if_not(self, condition): + return self.from_vec(x.zero_if_not(condition) for x in self.v) def __str__(self): return 'sbitvec(%d)' % n return sbitvecn diff --git a/Compiler/types.py b/Compiler/types.py index 128fa2039..75e1f58f2 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -324,6 +324,9 @@ def __abs__(self): def popcnt_bits(bits): return sum(bits) + def zero_if_not(self, condition): + return condition * self + class _int(Tape._no_truth): """ Integer functionality. """ @@ -5331,7 +5334,7 @@ def maybe_get(self, condition, index): :param condition: 0/1 (regint/cint/int) :param index: regint/cint/int """ - return condition * self[condition * index] + return self[condition * index].zero_if_not(condition) def maybe_set(self, condition, index, value): """ Change entry if condition is true. From f469dfc4735d2cb3d8293a4cc5c8c2dcb4ec9171 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Tue, 26 Jul 2022 14:07:35 +0200 Subject: [PATCH 110/221] Add the INVPERM instruction The IVMPERM instruction takes in a secret shared vector representing a permutation, and returns the corresponding secret shared inverse permutation. --- Compiler/instructions.py | 20 +++++ Compiler/instructions_base.py | 1 + Compiler/types.py | 5 ++ Processor/Instruction.h | 1 + Processor/Instruction.hpp | 7 ++ Processor/Processor.h | 3 + Processor/Processor.hpp | 16 ++++ Protocols/SecureShuffle.h | 50 +++++++++++- Protocols/SecureShuffle.hpp | 143 +++++++++++++++++++++++++++------- 9 files changed, 216 insertions(+), 30 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 5d6bf5fc6..058b6ff4f 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2501,6 +2501,26 @@ class delshuffle(base.Instruction): code = base.opcodes['DELSHUFFLE'] arg_format = ['ci'] +class inverse_permutation(base.VectorInstruction, shuffle_base): + """ Calculate the inverse permutation of a secret permutation. + + :param: destination (sint) + :param: source (sint) + + """ + __slots__ = [] + code = base.opcodes['INVPERM'] + arg_format = ['sw', 's'] + + def __init__(self, *args, **kwargs): + super(inverse_permutation, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + + def add_usage(self, req_node): + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), 1) + + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 3a56e6043..f7aa48f9b 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -111,6 +111,7 @@ GENSECSHUFFLE = 0xFB, APPLYSHUFFLE = 0xFC, DELSHUFFLE = 0xFD, + INVPERM = 0xFE, # Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Compiler/types.py b/Compiler/types.py index 75e1f58f2..a69e5e520 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2776,6 +2776,11 @@ def secure_permute(self, shuffle, unit_size=1, reverse=False): applyshuffle(res, self, unit_size, shuffle, reverse) return res + def inverse_permutation(self): + res = sint(size=self.size) + inverse_permutation(res, self) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ diff --git a/Processor/Instruction.h b/Processor/Instruction.h index f3caf5659..1de58c994 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -113,6 +113,7 @@ enum GENSECSHUFFLE = 0xFB, APPLYSHUFFLE = 0xFC, DELSHUFFLE = 0xFD, + INVPERM = 0xFE, // Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 1d7c883de..7763c8377 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -283,6 +283,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) n = get_int(s); get_vector(2, start, s); break; + // instructions with 2 register operands + case INVPERM: + get_vector(2, start, s); + break; // open instructions + read/write instructions with variable length args case OPEN: case GOPEN: @@ -1076,6 +1080,9 @@ inline void Instruction::execute(Processor& Proc) const case DELSHUFFLE: Proc.Procp.delete_shuffle(Proc.read_Ci(r[0])); return; + case INVPERM: + Proc.Procp.inverse_permutation(*this); + return; case CHECK: { CheckJob job; diff --git a/Processor/Processor.h b/Processor/Processor.h index 927e93279..e29f6eb43 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -77,6 +77,7 @@ class SubProcessor size_t generate_secure_shuffle(const Instruction& instruction); void apply_shuffle(const Instruction& instruction, int handle); void delete_shuffle(int handle); + void inverse_permutation(const Instruction& instruction); void input_personal(const vector& args); void send_personal(const vector& args); @@ -101,6 +102,8 @@ class SubProcessor { return C[i]; } + + void inverse_permutation(const Instruction &instruction, int handle); }; class ArithmeticProcessor : public ProcessorBase diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 861e8cfe0..e80df0d04 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -668,6 +668,12 @@ void SubProcessor::delete_shuffle(int handle) shuffler.del(handle); } +template +void SubProcessor::inverse_permutation(const Instruction& instruction) { + shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0], + instruction.get_start()[1]); +} + template void SubProcessor::input_personal(const vector& args) { @@ -686,6 +692,16 @@ void SubProcessor::input_personal(const vector& args) S[args[i + 2] + j] = input.finalize(args[i + 1]); } +/** + * + * @tparam T + * @param args Args contains four arguments + * a[0] = the size of the input (and output) vector + * a[1] = the player to which to reveal the output + * a[2] = the memory address of the input vector (sint) (i.e. the value to reveal) + * a[3] = the memory address of the output vector (cint) (i.e. the register to store the revealed value) + * // TODO: When would there be multiple sets of arguments? (for ... i < args.size(); i += 4 ... ) + */ template void SubProcessor::private_output(const vector& args) { diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index a90c6e64f..c1c265ea8 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -24,8 +24,28 @@ class SecureShuffle size_t n_shuffle; bool exact; + /** + * Generates and returns a newly generated random permutation. This permutation is generated locally. + * + * @param n The size of the permutation to generate. + * @return A vector representing a permutation, a shuffled array of integers 0 through n-1. + */ + vector generate_random_permutation(int n); + + /** + * Configure a shared waksman network from a permutation known only to config_player. + * Note that although the configuration bits of the waksman network are secret shared, + * the player that generated the permutation (config_player) knows the value of these bits. + * + * A permutation is a mapping represented as a vector. + * Each item in the vector represents the output of mapping(i) where i is the index of that item. + * e.g. [2, 4, 0, 3, 1] -> perm(1) = 4 + * + * @param config_player The player tasked with generating the random permutation from which to configure the waksman network. + * @param n_shuffle The size of the permutation to generate. + */ + void configure(int config_player, vector* perm, int n); void player_round(int config_player); - void generate(int config_player, int n_shuffle); void waksman(vector& a, int depth, int start); void cond_swap(T& x, T& y, const T& b); @@ -44,9 +64,37 @@ class SecureShuffle int generate(int n_shuffle); + /** + * + * @param a The vector of registers representing the stack // TODO: Is this correct? + * @param n The size of the input vector to shuffle + * @param unit_size Determines how many vector items constitute a single block with regards to permutation: + * i.e. input vector [1,2,3,4] with unit_size=2 under permutation map [1,0] + * would result in [3,4,1,2] + * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) + * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) + * @param handle The integer identifying the preconfigured waksman network (shuffle) to use. Such a handle can be obtained from calling + * @param reverse Boolean indicating whether to apply the inverse of the permutation + * @see SecureShuffle::generate for obtaining a shuffle handle + */ void apply(vector& a, size_t n, int unit_size, size_t output_base, size_t input_base, int handle, bool reverse); + /** + * Calculate the secret inverse permutation of stack given secret permutation. + * + * This method is given in [1], based on stack technique in [2]. It is used in the Compiler (high-level) implementation of Square-Root ORAM. + * + * [1] Samee Zahur, Xiao Wang, Mariana Raykova, Adrià Gascón, Jack Doerner, David Evans, and Jonathan Katz. 2016. Revisiting Square Root ORAM: Efficient Random Access in Multi-Party Computation. In IEEE S&P. + * [2] Ivan Damgård, Matthias Fitzi, Eike Kiltz, Jesper Buus Nielsen, and Tomas Toft. Unconditionally Secure Constant-rounds Multi-Party Computation for Equality, Comparison, Bits and Exponentiation. In Theory of Cryptography, 2006. + * + * @param stack The vector or registers representing the stack (?) + * @param n The size of the input vector for which to calculate the inverse permutation + * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) + * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) + */ + void inverse_permutation(vector& stack, size_t n, size_t output_base, size_t input_base); + void del(int handle); }; diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index d2b0676ac..5d713066c 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -58,6 +58,82 @@ void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t outpu post(a, n, output_base); } + +template +void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t output_base, + size_t input_base) { + int alice = 0; + int bob = 1; + + auto &P = proc.P; + auto &input = proc.input; + + // This method only supports two players + assert(proc.protocol.get_relevant_players().size() == 2); + // The current implementation assumes a semi-honest environment + assert(!T::malicious); + + // We are dealing directly with permutations, so the unit_size will always be 1. + this->unit_size = 1; + // We need to account for sizes which are not a power of 2 + size_t n_pow2 = (1u << int(ceil(log2(n)))); + + // Copy over the input registers + pre(stack, n, input_base); + // Alice generates stack local permutation and shares the waksman configuration bits secretly to Bob. + vector perm_alice(n_pow2); + if (P.my_num() == alice) + perm_alice = generate_random_permutation(n); + configure(alice, &perm_alice, n); + // Apply perm_alice to perm_alice to get perm_bob, + // stack permutation that we can reveal to Bob without Bob learning anything about perm_alice (since it is masked by perm_a) + iter_waksman(true); + // Store perm_bob at stack[output_base] + post(stack, n, output_base); + + // Reveal permutation perm_bob = perm_a * perm_alice + // Since this permutation is masked by perm_a, Bob learns nothing about perm + vector perm_bob(n_pow2); + typename T::PrivateOutput output(proc); + for (size_t i = 0; i < n; i++) + output.prepare_sending(stack[output_base + i], bob); + output.exchange(); + for (size_t i = 0; i < n_pow2; i++) { + // TODO: Is there a better way to convert a T::clear to int? + bigint val; + output.finalize(bob).to(val); + perm_bob[i] = (int) val.get_si(); + } + + vector perm_bob_inv(n_pow2); + if (P.my_num() == bob) { + for (int i = 0; i < (int) n; i++) + perm_bob_inv[perm_bob[i]] = i; + // Pad the permutation to n_pow2 + // Required when using waksman networks + for (int i = (int) n; i < (int) n_pow2; i++) + perm_bob_inv[i] = i; + } + + // Alice secret shares perm_a with bob + // perm_a is stored in the stack at output_base + input.reset_all(P); + if (P.my_num() == alice) { + for (int i = 0; i < (int) n; i++) + input.add_mine(perm_alice[i]); + } + input.exchange(); + for (int i = 0; i < (int) n; i++) + stack[output_base + i] = input.finalize(alice); + + // The two parties now jointly compute perm_a * perm_bob_inv to obtain perm_inv + pre(stack, n, output_base); + configure(bob, &perm_bob_inv, n); + iter_waksman(true); + // perm_inv is written back to stack[output_base] + post(stack, n, output_base); +} + template void SecureShuffle::del(int handle) { @@ -129,9 +205,27 @@ void SecureShuffle::post(vector& a, size_t n, size_t output_base) } template -void SecureShuffle::player_round(int config_player) -{ - generate(config_player, n_shuffle); +vector SecureShuffle::generate_random_permutation(int n) { + vector perm; + int n_pow2 = 1 << int(ceil(log2(n))); + int shuffle_size = n; + for (int j = 0; j < n_pow2; j++) + perm.push_back(j); + SeededPRNG G; + for (int i = 0; i < shuffle_size; i++) { + int j = G.get_uint(shuffle_size - i); + swap(perm[i], perm[i + j]); + } + + return perm; +} + +template +void SecureShuffle::player_round(int config_player) { + vector random_perm(n_shuffle); + if (proc.P.my_num() == config_player) + random_perm = generate_random_permutation(n_shuffle); + configure(config_player, &random_perm, n_shuffle); iter_waksman(); } @@ -142,9 +236,12 @@ int SecureShuffle::generate(int n_shuffle) shuffles.push_back({}); auto& shuffle = shuffles.back(); - for (auto i : proc.protocol.get_relevant_players()) - { - generate(i, n_shuffle); + for (auto i: proc.protocol.get_relevant_players()) { + vector perm; + if (proc.P.my_num() == i) + perm = generate_random_permutation(n_shuffle); + configure(i, &perm, n_shuffle); + shuffle.push_back(config); } @@ -152,39 +249,27 @@ int SecureShuffle::generate(int n_shuffle) } template -void SecureShuffle::generate(int config_player, int n) -{ - auto& P = proc.P; - auto& input = proc.input; +void SecureShuffle::configure(int config_player, vector *perm, int n) { + auto &P = proc.P; + auto &input = proc.input; input.reset_all(P); int n_pow2 = 1 << int(ceil(log2(n))); Waksman waksman(n_pow2); - if (P.my_num() == config_player) - { - vector perm; - int shuffle_size = n; - for (int j = 0; j < n_pow2; j++) - perm.push_back(j); - SeededPRNG G; - for (int i = 0; i < shuffle_size; i++) - { - int j = G.get_uint(shuffle_size - i); - swap(perm[i], perm[i + j]); - } - - auto config_bits = waksman.configure(perm); - for (size_t i = 0; i < config_bits.size(); i++) - { - auto& x = config_bits[i]; + // The player specified by config_player configures the shared waksman network + // using its personal permutation + if (P.my_num() == config_player) { + auto config_bits = waksman.configure(*perm); + for (size_t i = 0; i < config_bits.size(); i++) { + auto &x = config_bits[i]; for (size_t j = 0; j < x.size(); j++) if (waksman.matters(i, j)) input.add_mine(int(x[j])); else assert(x[j] == 0); } - } - else + // The other player waits for its share of the configured waksman network + } else for (size_t i = 0; i < waksman.n_bits(); i++) input.add_other(config_player); From 70135dd2fecd638abf8d14a9af1ce39b7e35c43a Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Tue, 26 Jul 2022 16:45:03 +0200 Subject: [PATCH 111/221] Fix segfault in INVPERM instruction --- Protocols/SecureShuffle.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index 5d713066c..920ccf3ac 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -98,7 +98,7 @@ void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t ou for (size_t i = 0; i < n; i++) output.prepare_sending(stack[output_base + i], bob); output.exchange(); - for (size_t i = 0; i < n_pow2; i++) { + for (size_t i = 0; i < n; i++) { // TODO: Is there a better way to convert a T::clear to int? bigint val; output.finalize(bob).to(val); From 52ac60beb03dc02b85894f347cfd775ad56fca51 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 11 Aug 2022 15:35:51 +0200 Subject: [PATCH 112/221] Add --invperm flag for the INVPERM instruction --- Compiler/program.py | 63 +++++++++++++++++++++++++++++---------------- Compiler/types.py | 16 ++++++++++-- compile.py | 2 ++ 3 files changed, 57 insertions(+), 24 deletions(-) diff --git a/Compiler/program.py b/Compiler/program.py index e06418f3d..435bf7e59 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -49,6 +49,7 @@ class defaults: budget = 100000 mixed = False edabit = False + invperm = False split = None cisc = False comparison = None @@ -142,6 +143,8 @@ def __init__(self, args, options=defaults): self.use_dabit = options.mixed """ Setting whether to use daBits for non-linear functionality. """ self._edabit = options.edabit + """ Whether to use the low-level INVPERM instruction (only implemented with the assumption of a semi-honest two-party environment)""" + self._invperm = options.invperm self._split = False if options.split: self.use_split(int(options.split)) @@ -167,7 +170,7 @@ def max_par_tapes(self): """ Upper bound on number of tapes that will be run in parallel. (Excludes empty tapes) """ return self.n_threads - + def init_names(self, args): # ignore path to file - source must be in Programs/Source if 'Programs' in os.listdir(os.getcwd()): @@ -178,16 +181,16 @@ def init_names(self, args): self.programs_dir = sys.path[0] + '/Programs' if self.verbose: print('Compiling program in', self.programs_dir) - + # create extra directories if needed for dirname in ['Public-Input', 'Bytecode', 'Schedules']: if not os.path.exists(self.programs_dir + '/' + dirname): os.mkdir(self.programs_dir + '/' + dirname) - + progname = args[0].split('/')[-1] if progname.endswith('.mpc'): progname = progname[:-4] - + if os.path.exists(args[0]): self.infile = args[0] else: @@ -314,7 +317,7 @@ def update_req(self, tape): self.req_num = tape.req_num else: self.req_num += tape.req_num - + def write_bytes(self): """ Write all non-empty threads and schedule to files. """ @@ -349,7 +352,7 @@ def finalize_tape(self, tape): if self.options.asmoutfile: tape.write_str(self.options.asmoutfile + '-' + tape.name) tape.purge() - + @property def curr_tape(self): """ The tape that is currently running.""" @@ -367,7 +370,7 @@ def curr_tape(self, value): def curr_block(self): """ The basic block that is currently being created. """ return self.curr_tape.active_basicblock - + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): """ Allocate memory from the top """ if not isinstance(size, int): @@ -514,6 +517,20 @@ def use_edabit(self, change=None): else: self._edabit = change + def use_invperm(self, change=None): + """ Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used. + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ + if change is None: + if not self._invperm: + self.relevant_opts.add('invperm') + return self._invperm + else: + self._invperm = change + + def use_edabit_for(self, *args): return True @@ -574,6 +591,8 @@ def options_from_args(self): self.always_raw(True) if 'edabit' in self.args: self.use_edabit(True) + if 'invperm' in self.args: + self.use_invperm(True) if 'linear_rounds' in self.args: self.linear_rounds(True) @@ -658,7 +677,7 @@ def set_return(self, previous_block, sub_block): def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset - + def set_exit(self, condition, exit_true=None): """ Sets the block which we start from next, depending on the condition. @@ -668,15 +687,15 @@ def set_exit(self, condition, exit_true=None): self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False - + def add_jump(self): """ Add the jump for this block's exit condition to list of instructions (must be done after merging) """ self.instructions.append(self.exit_condition) - + def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) - + def adjust_jump(self): """ Set the correct relative jump offset """ offset = self.get_offset(self.exit_block) @@ -749,7 +768,7 @@ def start_new_basicblock(self, scope=False, name=''): def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) - + def init_names(self, name): self.name = name self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' @@ -863,7 +882,7 @@ def optimize(self, options): print('Re-allocating...') allocator = al.StraightlineAllocator(REG_MAX, self.program) def alloc(block): - for reg in sorted(block.used_from_scope, + for reg in sorted(block.used_from_scope, key=lambda x: (x.reg_type, x.i)): allocator.alloc_reg(reg, block.alloc_pool) def alloc_loop(block): @@ -955,12 +974,12 @@ def _get_instructions(self): def get_encoding(self): """ Get the encoding of the program, in human-readable format. """ return [i.get_encoding() for i in self._get_instructions() if i is not None] - + @unpurged def get_bytes(self): """ Get the byte encoding of the program as an actual string of bytes. """ return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) - + @unpurged def write_encoding(self, filename): """ Write the readable encoding to a file. """ @@ -969,7 +988,7 @@ def write_encoding(self, filename): for line in self.get_encoding(): f.write(str(line) + '\n') f.close() - + @unpurged def write_str(self, filename): """ Write the sequence of instructions to a file. """ @@ -983,7 +1002,7 @@ def write_str(self, filename): f.write('%s # %d\n' % (line, n)) n += 1 f.close() - + @unpurged def write_bytes(self, filename=None): """ Write the program's byte encoding to a file. """ @@ -999,16 +1018,16 @@ def write_bytes(self, filename=None): if i is not None: f.write(i.get_bytes()) f.close() - + def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) - + def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] - + def __str__(self): return self.name @@ -1018,7 +1037,7 @@ def __init__(self, init={}): def __add__(self, other): res = Tape.ReqNum() for i,count in list(self.items()): - res[i] += count + res[i] += count for i,count in list(other.items()): res[i] += count return res @@ -1267,7 +1286,7 @@ def link(self, other): def is_gf2n(self): return self.reg_type == RegType.ClearGF2N or \ self.reg_type == RegType.SecretGF2N - + @property def is_clear(self): return self.reg_type == RegType.ClearModp or \ diff --git a/Compiler/types.py b/Compiler/types.py index a69e5e520..d63295c8f 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2777,8 +2777,20 @@ def secure_permute(self, shuffle, unit_size=1, reverse=False): return res def inverse_permutation(self): - res = sint(size=self.size) - inverse_permutation(res, self) + if program.use_invperm(): + # If enabled, we use the low-level INVPERM instruction. + # This instruction has only been implemented for a semi-honest two-party environement. + res = sint(size=self.size) + inverse_permutation(res, self) + else: + shuffle = sint.get_secure_shuffle(len(self)) + shuffled = self.secure_permute(shuffle).reveal() + idx = Array.create_from(shuffled) + res = Array.create_from(sint(regint.inc(len(self)))) + res.secure_permute(shuffle, reverse=False) + res.assign_slice_vector(idx, res.get_vector()) + library.break_point() + res = res.get_vector() return res class sintbit(sint): diff --git a/compile.py b/compile.py index da1b69ee3..2455946b3 100755 --- a/compile.py +++ b/compile.py @@ -72,6 +72,8 @@ def main(): help="mixing arithmetic and binary computation") parser.add_option("-Y", "--edabit", action="store_true", dest="edabit", help="mixing arithmetic and binary computation using edaBits") + parser.add_option("--invperm", action="store_true", dest="invperm", + help="speedup inverse permutation (only use in two-party, semi-honest environment)") parser.add_option("-Z", "--split", default=defaults.split, dest="split", help="mixing arithmetic and binary computation " "using direct conversion if supported " From 9e9210e683d29f77919c164040ba773b3154a0a3 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 28 Jul 2022 12:28:11 +0200 Subject: [PATCH 113/221] Add bit_not to MemValue --- Compiler/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Compiler/types.py b/Compiler/types.py index d63295c8f..5b7441a56 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -6708,6 +6708,7 @@ def reveal(self): if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) bit_and = lambda self,other: self.read().bit_and(other) + bit_not = lambda self: self.read().bit_not() def expand_to_vector(self, size=None): if program.curr_block == self.last_write_block: From 06520ea7a11451bbe77fa279d7c7abc12208d786 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 28 Jul 2022 12:28:47 +0200 Subject: [PATCH 114/221] Add SqrtORAM to Compiler --- Compiler/sqrt_oram.py | 476 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 476 insertions(+) create mode 100644 Compiler/sqrt_oram.py diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py new file mode 100644 index 000000000..5525aad44 --- /dev/null +++ b/Compiler/sqrt_oram.py @@ -0,0 +1,476 @@ +from __future__ import annotations +from abc import abstractmethod +from typing import Callable, Generic, Iterable, Literal, Type, Any, TypeVar +from Compiler import library as lib +from Compiler.GC.types import cbit, sbit, sbitint, sbits +from Compiler.oram import AbstractORAM, get_n_threads +from Compiler.types import MultiArray, sgf2n, sint, _secret, MemValue, Array, _clear, sintbit, cint +import numpy as np + +debug = True +reveal = True +n_parallel = 1024 + +def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): + if isinstance(array, MultiArray): + temp = array[pos_b][:] + array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:])) + array[pos_a].assign(cond.if_else(temp, array[pos_a][:])) + if isinstance(array, Array): + temp = array[pos_b] + array[pos_b] = cond.if_else(array[pos_a], array[pos_b]) + array[pos_a] = cond.if_else(temp, array[pos_a]) + +T = TypeVar("T", sint, sbitint) +B = TypeVar("B", sintbit, sbit) + +class SqrtOram(Generic[T, B]): + # TODO: Preferably this is an Array of vectors, but this is currently not supported + # One should regard these structures as Arrays where an entry may hold more + # than one value (which is a nice property to have when using the ORAM in + # practise). + shuffle: MultiArray + stash: MultiArray + # A block has an index and data + # `shuffle` and `stash` store the data, + # `shufflei` and `stashi` store the index + shufflei: Array + stashi: Array + + shuffle_used: Array + position_map: PositionMap + + # The size of the ORAM, i.e. how many elements it stores + n: int + # The period, i.e. how many calls can be made to the ORAM before it needs to be refreshed + T: int + # Keep track of how far we are in the period, and coincidentally how large + # the stash is (each access results in a fake or real block being put on + # the stash) + t: cint + + def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: + """Initialize a new Oblivious RAM using the "Square-Root" algorithm. + + Args: + data (MultiArray): The data with which to initialize the ORAM. For all intents and purposes, data is regarded as a one-dimensional Array. However, one may provide a MultiArray such that every "block" can hold multiple elements (an Array). + value_type (sint): The secret type to use, defaults to sint. + k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. + period (int): Leave at None, this parameter is used to recursively pass down the top-level period. + """ + self.n = len(data) + + self.value_type = value_type + if value_type != sint and value_type != sbitint: + raise Exception("The value_type must be either sint or sbitint") + self.bit_type: Type[B] = value_type.bit_type + self.index_type = value_type.get_type(int(np.ceil(np.log2(self.n)) )) + self.entry_length = entry_length + + if debug: + lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) + self.shuffle_used = cint.Array(self.n) + # Random permutation on the data + self.shuffle = data + self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) + permutation = Array.create_from(self.shuffle_the_shuffle()) + # Calculate the period if not given + # upon recursion, the period should stay the same ("in sync"), + # therefore it can be passed as a constructor parameter + self.T = int(np.ceil(np.sqrt(self.n * np.log2(self.n) - self.n + 1)) + ) if not period else period + if debug and not period: + lib.print_ln('Period set to %s', self.T) + # Initialize position map (recursive oram) + self.position_map = PositionMap.create(permutation, k + 1, self.T) + + # Initialize stash + self.stash = MultiArray((self.T, data.sizes[1]), value_type=value_type) + self.stashi = Array(self.T, value_type=value_type) + self.t = MemValue(cint(0)) + + + def read(self, index: T): + data = self.value_type.Array(self.entry_length) + return self.access(index, self.bit_type(False), data) + + def write(self, index: T, value: Array): + self.access(index, self.bit_type(True), value) + + __getitem__ = read + __setitem__ = write + + def access(self, index: T, write: B, value: Array): + if len(value) != self.entry_length: + raise Exception("A block must be of size entry_length={}".format(self.entry_length)) + # Method Blocks do not accepts arrays as arguments + # workaround by temporarily storing it as a class field + # arrays are stored in memory so this is fine + index = MemValue(index) + return Array.create_from(self._access(index, write, value[:])) + + @lib.method_block + def _access(self, index: T, write: B, *value: list[T]): + item: T = self.value_type(*value) + + if debug: + @lib.if_e(write.reveal() == 1) + def _(): + lib.print_ln('Writing to secret index %s', index.reveal()) + @lib.else_ + def __(): + lib.print_ln('Reading from secret index %s', index.reveal()) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + + # Scan through the stash + @lib.if_(self.t > 0) + def _(): + nonlocal found + found |= index == self.stashi[0] + # We ensure that if the item is found in stash, it ends up in the first + # position (more importantly, a fixed position) of the stash + # This allows us to keep track of it in an oblivious manner + @lib.for_range_opt(self.t) + def _(i): + nonlocal found + found_: B = index == self.stashi[i + 1] + swap(self.stash, 0, i, found_) + swap(self.stashi, 0, i, found_) + found |= found_ + # found = self.bit_type(found.bit_or(found_)) + # If the item was not in the stash, we move the unknown and unimportant + # stash[0] out of the way (to the end of the stash) + swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) + swap(self.stashi, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) + + if debug: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln(' Found item in stash') + @lib.else_ + def __(): + lib.print_ln(' Item not in stash') + lib.print_ln(' Moved stash[0]=(%s: %s) to stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + self.shuffle_used[physical_address] = cbit(True) + + # If the item was in the stash (thus currently residing in stash[0]), + # we place the random item retrieved from the shuffle at the end of the stash + self.stash[self.t].assign(found.if_else( + self.shuffle[physical_address][:], + self.stash[self.t][:])) + self.stashi[self.t] = found.if_else( + self.shufflei[physical_address], + self.stashi[self.t]) + # If the item was not found in the stash, + # we place the item retrieved from the shuffle in stash[0] + self.stash[0].assign(found.bit_not().if_else( + self.shuffle[physical_address][:], + self.stash[0][:])) + self.stashi[0] = found.bit_not().if_else( + self.shufflei[physical_address], + self.stashi[0]) + if debug: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[t]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + @lib.else_ + def __(): + lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[0]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + self.stash[0].assign(write.if_else(item, self.stash[0][:])) + item=write.bit_not().if_else(self.stash[0][:], item) + return item + + + @lib.method_block + def shuffle_the_shuffle(self): + """Permute the memory using a newly generated permutation and return + the permutation that would generate this particular shuffling. + + This permutation is needed to know how to map logical addresses to + physical addresses, and is used as such by the postition map.""" + + # Random permutation on n elements + random_shuffle = sint.get_secure_shuffle(self.n) + # Apply the random permutation + lib.print_ln('\tGenerated shuffle') + self.shuffle.secure_permute(random_shuffle) + lib.print_ln('\tShuffled shuffle') + self.shufflei.secure_permute(random_shuffle) + lib.print_ln('\tShuffled shuffle indexes') + # Calculate the permutation that would have produced the newly produced + # shuffle order. This can be calculated by regarding the logical + # indexes (shufflei) as a permutation and calculating its inverse, + # i.e. find P such that P([1,2,3,...]) = shufflei. + # this is not necessarily equal to the inverse of the above generated + # random_shuffle, as the shuffle may already be out of order (e.g. when + # refreshing). + permutation = MemValue(self.shufflei[:].inverse_permutation()) + lib.print_ln('\tCalculated inverse permutation') + return permutation + + @lib.method_block + def refresh(self): + """Refresh the ORAM by reinserting the stash back into the shuffle, and + reshuffling the shuffle. + + This must happen after T (period) accesses to the ORAM.""" + lib.print_ln('Refreshing SqrtORAM') + + # Shuffle and emtpy the stash, and store elements back into shuffle + j = MemValue(cint(0,size=1)) + @lib.for_range_opt(self.n) + def _(i): + @lib.if_(self.shuffle_used[i]) + def _(): + nonlocal j + self.shuffle[i] = self.stash[j] + self.shufflei[i] = self.stashi[j] + j += 1 + + # Reset the clock + self.t.write(0) + # Reset shuffle_used + self.shuffle_used.assign_all(0) + + # Reinitialize position map + permutation = self.shuffle_the_shuffle() + # Note that we skip here the step of "packing" the permutation. + # Since the underlying memory of the position map is already aligned in + # this packed structure, we can simply overwrite the memory while + # maintaining the structure. + self.position_map.reinitialize(*permutation) + + @lib.method_block + def reinitialize(self, *data: T): + # Note that this method is only used during refresh, and as such is + # only called with a permutation as data. + + # The logical addresses of some previous permutation are irrelevant and must be reset + self.shufflei.assign([self.index_type(i) for i in range(self.n)]) + # Reset the clock + self.t.write(0) + # Reset shuffle_used + self.shuffle_used.assign_all(0) + + # Note that the self.shuffle is actually a MultiArray + # This structure is preserved while overwriting the values using + # assign_vector + self.shuffle.assign_vector(self.value_type(data, size=self.n * self.entry_length)) + permutation = self.shuffle_the_shuffle() + self.position_map.reinitialize(*permutation) + + +class PositionMap(Generic[T, B]): + PACK_LOG: int = 2 + PACK: int = 1 << PACK_LOG + + n: int # n in the paper + depth: int # k in the paper + value_type: Type[T] + + def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None: + self.n = n + self.depth=MemValue(cint(k)) + self.value_type = value_type + self.bit_type = value_type.bit_type + self.index_type = self.value_type.get_type(int(np.ceil(np.log2(n)))) + + @abstractmethod + def get_position(self, logical_address: _secret, fake: B) -> Any: + """Retrieve the block at the given (secret) logical address.""" + if debug: + lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) + + def reinitialize(self, *permutation: T): + """Reinitialize this PositionMap. + + Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`), + we cannot simply call __init__ on self. Instead, we must take care to + reuse and overwrite the same memory. + """ + ... + + @classmethod + def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = sint) -> PositionMap: + """Creates a new PositionMap. This is the method one should call when + needing a new position map. Depending on the size of the given data, it + will either instantiate a RecursivePositionMap or + a LinearPositionMap.""" + n = len(permutation) + + if n / PositionMap.PACK <= period: + if debug: + lib.print_ln('Initializing LinearPositionMap at depth %s of size %s', k, n) + res = LinearPositionMap(permutation, value_type, k=k) + else: + if debug: + lib.print_ln('Initializing RecursivePositionMap at depth %s of size %s', k, n) + res = RecursivePositionMap(permutation, period, value_type, k=k) + + return res + + +class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): + + def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k:int=-1) -> None: + PositionMap.__init__(self, len(permutation), k=k) + pack = PositionMap.PACK + + # We pack the permutation into a smaller structure, index with a new permutation + packed_size = int(np.ceil(self.n / pack)) + packed_structure = MultiArray( + (packed_size, pack), value_type=value_type) + for i in range(packed_size): + packed_structure[i] = Array.create_from( + permutation[i*pack:(i+1)*pack]) + + # TODO: Should this be n or packed_size? + SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth) + + @lib.method_block + def get_position(self, logical_address: T, fake: B) -> _clear: + super().get_position(logical_address, fake) + + pack = PositionMap.PACK + pack_log = PositionMap.PACK_LOG + + # The item at logical_address + # will be in block with index h (block.) + # at position l in block.data (block.data) + h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(logical_address).right_shift(pack_log, program.bit_length))) + l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) + + # The resulting physical address + p = MemValue(self.index_type(0)) + found: B = MemValue(self.bit_type(False)) + + # First we try and retrieve the item from the stash + + # We retrieve stash[h] + # Since h is secret, we do this by scanning the entire stash + @lib.for_range(self.t) + def _(j): + nonlocal found + condition = self.stashi[j] == h + found |= condition + # block = stash[h] + # block is itself an array (it holds a permutation) + # we need to grab block[l] + @lib.for_range(pack) + def _(i): + nonlocal condition + condition &= l == i + p.write(condition.if_else(self.stash[j][i], p)) + + if debug: + @lib.if_(condition.reveal() == 1) + def _(): + lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, j, self.stashi[j].reveal(), self.stash[j].reveal()) + + # Then we try and retrieve the item from the shuffle (the actual memory) + + if debug: + @lib.if_(found.reveal() == 0) + def _(): + lib.print_ln('\t%s Position not in stash', self.depth) + + + p_prime = self.position_map.get_position(h, found) + self.shuffle_used[p_prime] = cbit(True) + # The block retrieved from the shuffle + # Depending on whether the block has already been `found`, this block + # is either the desired block (found=False) or a random block + # (found=True) + block_p_prime: Array = self.shuffle[p_prime] + + if debug: + @lib.if_e(found.reveal() == 0) + def _(): + lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + @lib.else_ + def __(): + lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + + # We add the retrieved block from the shuffle to the stash + self.stash[self.t].assign(block_p_prime[:]) + self.stashi[self.t] = self.shufflei[p_prime] + # Increase t + self.t += 1 + + # if found or not fake + condition = self.bit_type(fake.bit_or(found.bit_not())) + # Retrieve l'th item from block + # l is secret, so we must use linear scan + @lib.for_range_opt(pack) + def _(i): + hit: B = self.bit_type(i == l) + p.write((condition & hit).if_else(block_p_prime[i], p)) + + return p.reveal() + + @lib.method_block + def reinitialize(self, *permutation: T): + SqrtOram.reinitialize(self, *permutation) + +class LinearPositionMap(PositionMap): + physical: Array + used: Array + + def __init__(self, data: Array, value_type: Type[T] = sint, k:int =-1) -> None: + PositionMap.__init__(self, len(data), value_type, k=k) + self.physical = data + self.used = self.bit_type.Array(self.n) + + @lib.method_block + def get_position(self, logical_address: T, fake: B) -> _clear: + """ + This method corresponds to GetPosBase in the paper. + """ + super().get_position(logical_address, fake) + fake = self.bit_type(fake) + + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + linear_scan = self.bit_type.Array(self.n) + @lib.for_range_opt(self.n) + def _(i): + linear_scan[i] = logical_address == i + + p: MemValue = MemValue(self.index_type(-1)) + done: B = self.bit_type(False) + + @lib.for_range_opt(self.n) + def _(j): + nonlocal done, fake + condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ + .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) + p.write(condition.if_else(self.physical[j], p)) + self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) + done = self.bit_type(condition.if_else(self.bit_type(True), done)) + + if debug: + @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) + def _(): + lib.runtime_error('%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) + + return p.reveal() + + @lib.method_block + def reinitialize(self, *data: T): + self.physical.assign_vector(data) + self.used.assign_all(False) From b070c23a26fce2685e1871576a78d914f877b528 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 11:24:07 +0200 Subject: [PATCH 115/221] Optimize performance of SqrtORAM --- Compiler/sqrt_oram.py | 192 +++++++++++++++++++++++++----------------- 1 file changed, 115 insertions(+), 77 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5525aad44..e949e72eb 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -1,17 +1,41 @@ from __future__ import annotations from abc import abstractmethod -from typing import Callable, Generic, Iterable, Literal, Type, Any, TypeVar +import math +from typing import Any, Generic, Type, TypeVar + +from Compiler.program import Program +from Compiler import util from Compiler import library as lib from Compiler.GC.types import cbit, sbit, sbitint, sbits -from Compiler.oram import AbstractORAM, get_n_threads -from Compiler.types import MultiArray, sgf2n, sint, _secret, MemValue, Array, _clear, sintbit, cint -import numpy as np - -debug = True -reveal = True +from Compiler.types import ( + Array, + MemValue, + MultiArray, + _clear, + _secret, + cint, + sint, + sintbit, + regint +) + +program = Program.prog + +debug = False n_parallel = 1024 +n_threads = 8 + +multithreading = True def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): + """Swap two positions in an Array if a condition is met. + + Args: + array (Array | MultiArray): The array in which to swap the first and second position + pos_a (int | cint): The first position + pos_b (int | cint): The second position + cond (sintbit | sbit): The condition determining whether to swap + """ if isinstance(array, MultiArray): temp = array[pos_b][:] array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:])) @@ -49,7 +73,7 @@ class SqrtOram(Generic[T, B]): # the stash) t: cint - def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: + def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: """Initialize a new Oblivious RAM using the "Square-Root" algorithm. Args: @@ -64,55 +88,51 @@ def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] if value_type != sint and value_type != sbitint: raise Exception("The value_type must be either sint or sbitint") self.bit_type: Type[B] = value_type.bit_type - self.index_type = value_type.get_type(int(np.ceil(np.log2(self.n)) )) + self.index_type = value_type.get_type(util.log2(self.n)) self.entry_length = entry_length if debug: lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) self.shuffle_used = cint.Array(self.n) # Random permutation on the data - self.shuffle = data + if isinstance(data, MultiArray): + self.shuffle = data + elif isinstance(data, sint): + self.shuffle = MultiArray((self.n, self.entry_length), value_type=value_type) + self.shuffle.assign_vector(data.get_vector()) + else: + raise Exception("Incorrect format.") self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter - self.T = int(np.ceil(np.sqrt(self.n * np.log2(self.n) - self.n + 1)) - ) if not period else period + self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period if debug and not period: lib.print_ln('Period set to %s', self.T) # Initialize position map (recursive oram) self.position_map = PositionMap.create(permutation, k + 1, self.T) # Initialize stash - self.stash = MultiArray((self.T, data.sizes[1]), value_type=value_type) + self.stash = MultiArray((self.T, entry_length), value_type=value_type) self.stashi = Array(self.T, value_type=value_type) self.t = MemValue(cint(0)) - + @lib.method_block def read(self, index: T): - data = self.value_type.Array(self.entry_length) - return self.access(index, self.bit_type(False), data) + value = self.value_type(0, size=self.entry_length) + return self.access(index, self.bit_type(False), *value) - def write(self, index: T, value: Array): - self.access(index, self.bit_type(True), value) + @lib.method_block + def write(self, index: T, value: T): + lib.runtime_error_if(value.size != self.entry_length, "A block must be of size entry_length") + self.access(index, self.bit_type(True), *value) __getitem__ = read __setitem__ = write - def access(self, index: T, write: B, value: Array): - if len(value) != self.entry_length: - raise Exception("A block must be of size entry_length={}".format(self.entry_length)) - # Method Blocks do not accepts arrays as arguments - # workaround by temporarily storing it as a class field - # arrays are stored in memory so this is fine - index = MemValue(index) - return Array.create_from(self._access(index, write, value[:])) - @lib.method_block - def _access(self, index: T, write: B, *value: list[T]): - item: T = self.value_type(*value) - + def access(self, index: T, write: B, *value: T): if debug: @lib.if_e(write.reveal() == 1) def _(): @@ -120,6 +140,7 @@ def _(): @lib.else_ def __(): lib.print_ln('Reading from secret index %s', index.reveal()) + value = self.value_type(value) # Refresh if we have performed T (period) accesses @lib.if_(self.t == self.T) @@ -136,14 +157,24 @@ def _(): # We ensure that if the item is found in stash, it ends up in the first # position (more importantly, a fixed position) of the stash # This allows us to keep track of it in an oblivious manner - @lib.for_range_opt(self.t) - def _(i): - nonlocal found - found_: B = index == self.stashi[i + 1] - swap(self.stash, 0, i, found_) - swap(self.stashi, 0, i, found_) - found |= found_ - # found = self.bit_type(found.bit_or(found_)) + if multithreading: + found_ = self.bit_type.Array(size=self.T) + @lib.multithread(8, self.T) + def _(base, size): + found_.assign_vector(self.stashi.get_vector(base, size)[:] == index, base=base) + @lib.for_range_opt(self.t - 1) + def _(i): + swap(self.stash, 0, i, found_[i]) + swap(self.stashi, 0, i, found_[i]) + found.write(sum(found_)) + else: + @lib.for_range_opt(self.t - 1) + def _(i): + nonlocal found + found_: B = index == self.stashi[i + 1] + swap(self.stash, 0, i, found_) + swap(self.stashi, 0, i, found_) + found |= found_ # If the item was not in the stash, we move the unknown and unimportant # stash[0] out of the way (to the end of the stash) swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) @@ -156,7 +187,7 @@ def _(): @lib.else_ def __(): lib.print_ln(' Item not in stash') - lib.print_ln(' Moved stash[0]=(%s: %s) to stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + lib.print_ln(' Moved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -171,14 +202,15 @@ def __(): self.stashi[self.t] = found.if_else( self.shufflei[physical_address], self.stashi[self.t]) - # If the item was not found in the stash, - # we place the item retrieved from the shuffle in stash[0] + # If the item was not found in the stash, we place the item retrieved + # from the shuffle (the item we are actually looking for) in stash[0] self.stash[0].assign(found.bit_not().if_else( self.shuffle[physical_address][:], self.stash[0][:])) self.stashi[0] = found.bit_not().if_else( self.shufflei[physical_address], self.stashi[0]) + if debug: @lib.if_e(found.reveal() == 1) def _(): @@ -191,9 +223,9 @@ def __(): # Increase the "time" (i.e. access count in current period) self.t.iadd(1) - self.stash[0].assign(write.if_else(item, self.stash[0][:])) - item=write.bit_not().if_else(self.stash[0][:], item) - return item + self.stash[0].assign(write.if_else(value, self.stash[0][:])) + value=write.bit_not().if_else(self.stash[0][:], value) + return value @lib.method_block @@ -206,12 +238,12 @@ def shuffle_the_shuffle(self): # Random permutation on n elements random_shuffle = sint.get_secure_shuffle(self.n) + if debug: lib.print_ln('\tGenerated shuffle') # Apply the random permutation - lib.print_ln('\tGenerated shuffle') self.shuffle.secure_permute(random_shuffle) - lib.print_ln('\tShuffled shuffle') + if debug: lib.print_ln('\tShuffled shuffle') self.shufflei.secure_permute(random_shuffle) - lib.print_ln('\tShuffled shuffle indexes') + if debug: lib.print_ln('\tShuffled shuffle indexes') # Calculate the permutation that would have produced the newly produced # shuffle order. This can be calculated by regarding the logical # indexes (shufflei) as a permutation and calculating its inverse, @@ -220,7 +252,7 @@ def shuffle_the_shuffle(self): # random_shuffle, as the shuffle may already be out of order (e.g. when # refreshing). permutation = MemValue(self.shufflei[:].inverse_permutation()) - lib.print_ln('\tCalculated inverse permutation') + if debug: lib.print_ln('\tCalculated inverse permutation') return permutation @lib.method_block @@ -229,7 +261,8 @@ def refresh(self): reshuffling the shuffle. This must happen after T (period) accesses to the ORAM.""" - lib.print_ln('Refreshing SqrtORAM') + + if debug: lib.print_ln('Refreshing SqrtORAM') # Shuffle and emtpy the stash, and store elements back into shuffle j = MemValue(cint(0,size=1)) @@ -276,7 +309,7 @@ def reinitialize(self, *data: T): class PositionMap(Generic[T, B]): - PACK_LOG: int = 2 + PACK_LOG: int = 3 PACK: int = 1 << PACK_LOG n: int # n in the paper @@ -288,7 +321,7 @@ def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None: self.depth=MemValue(cint(k)) self.value_type = value_type self.bit_type = value_type.bit_type - self.index_type = self.value_type.get_type(int(np.ceil(np.log2(n)))) + self.index_type = self.value_type.get_type(util.log2(n)) @abstractmethod def get_position(self, logical_address: _secret, fake: B) -> Any: @@ -332,7 +365,7 @@ def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, pack = PositionMap.PACK # We pack the permutation into a smaller structure, index with a new permutation - packed_size = int(np.ceil(self.n / pack)) + packed_size = int(math.ceil(self.n / pack)) packed_structure = MultiArray( (packed_size, pack), value_type=value_type) for i in range(packed_size): @@ -359,28 +392,33 @@ def get_position(self, logical_address: T, fake: B) -> _clear: p = MemValue(self.index_type(0)) found: B = MemValue(self.bit_type(False)) - # First we try and retrieve the item from the stash + # First we try and retrieve the item from the stash at position stash[h][l] + # Since h and l are secret, we do this by scanning the entire stash - # We retrieve stash[h] - # Since h is secret, we do this by scanning the entire stash + # First we scan the stash for the block we need + condition1 = self.bit_type.Array(self.T) + @lib.for_range_opt_multithread(8, self.T) + def _(i): + condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t) + found = sum(condition1) + # Once a block is found, we use condition2 to pick the correct item from that block + condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack)) + # condition3 combines condition1 & condition2, only returning true at stash[h][l] + condition3 = self.bit_type.Array(self.T * pack) + @lib.for_range_opt_multithread(8, [self.T, pack]) + def _(i, j): + condition3[i*pack + j] = condition1[i] & condition2[j] + # Finally we use condition3 to conditionally write p @lib.for_range(self.t) - def _(j): - nonlocal found - condition = self.stashi[j] == h - found |= condition - # block = stash[h] - # block is itself an array (it holds a permutation) - # we need to grab block[l] + def _(i): @lib.for_range(pack) - def _(i): - nonlocal condition - condition &= l == i - p.write(condition.if_else(self.stash[j][i], p)) + def _(j): + p.write(condition3[i*pack + j].if_else(self.stash[i][j], p)) if debug: - @lib.if_(condition.reveal() == 1) + @lib.if_(condition1[i].reveal() == 1) def _(): - lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, j, self.stashi[j].reveal(), self.stash[j].reveal()) + lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal()) # Then we try and retrieve the item from the shuffle (the actual memory) @@ -389,22 +427,22 @@ def _(): def _(): lib.print_ln('\t%s Position not in stash', self.depth) - + # Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle p_prime = self.position_map.get_position(h, found) self.shuffle_used[p_prime] = cbit(True) + # The block retrieved from the shuffle - # Depending on whether the block has already been `found`, this block - # is either the desired block (found=False) or a random block - # (found=True) block_p_prime: Array = self.shuffle[p_prime] if debug: @lib.if_e(found.reveal() == 0) def _(): - lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) @lib.else_ def __(): - lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)', + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) # We add the retrieved block from the shuffle to the stash self.stash[self.t].assign(block_p_prime[:]) @@ -413,13 +451,13 @@ def __(): self.t += 1 # if found or not fake - condition = self.bit_type(fake.bit_or(found.bit_not())) + condition: B = self.bit_type(fake.bit_or(found.bit_not())) # Retrieve l'th item from block # l is secret, so we must use linear scan + hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack)) @lib.for_range_opt(pack) def _(i): - hit: B = self.bit_type(i == l) - p.write((condition & hit).if_else(block_p_prime[i], p)) + p.write((hit[i]).if_else(block_p_prime[i], p)) return p.reveal() From 3a4cceeedf9c4e2f6555d4b39b595972e5d4aa56 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 14:41:51 +0200 Subject: [PATCH 116/221] Fix misaligned stashi bug in sqrt_oram --- Compiler/sqrt_oram.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index e949e72eb..c4b14f99b 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -141,6 +141,7 @@ def _(): def __(): lib.print_ln('Reading from secret index %s', index.reveal()) value = self.value_type(value) + index = MemValue(index) # Refresh if we have performed T (period) accesses @lib.if_(self.t == self.T) @@ -159,21 +160,23 @@ def _(): # This allows us to keep track of it in an oblivious manner if multithreading: found_ = self.bit_type.Array(size=self.T) - @lib.multithread(8, self.T) + @lib.multithread(1, self.T) def _(base, size): - found_.assign_vector(self.stashi.get_vector(base, size)[:] == index, base=base) + found_.assign_vector((self.stashi.get_vector(base, size) == index.expand_to_vector(size)) + & self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) @lib.for_range_opt(self.t - 1) def _(i): - swap(self.stash, 0, i, found_[i]) - swap(self.stashi, 0, i, found_[i]) + swap(self.stash, 0, i + 1, found_[i+1]) + swap(self.stashi, 0, i + 1, found_[i+1]) found.write(sum(found_)) else: @lib.for_range_opt(self.t - 1) def _(i): nonlocal found found_: B = index == self.stashi[i + 1] - swap(self.stash, 0, i, found_) - swap(self.stashi, 0, i, found_) + swap(self.stash, 0, i + 1, found_) + swap(self.stashi, 0, i + 1, found_) found |= found_ # If the item was not in the stash, we move the unknown and unimportant # stash[0] out of the way (to the end of the stash) @@ -183,11 +186,11 @@ def _(i): if debug: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln(' Found item in stash') + lib.print_ln('\tFound item in stash') @lib.else_ def __(): - lib.print_ln(' Item not in stash') - lib.print_ln(' Moved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + lib.print_ln('\tItem not in stash') + lib.print_ln('\tMoved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash From 8af345a7138cebbc2ba120b9b8b840f4f8dd688c Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 17:35:22 +0200 Subject: [PATCH 117/221] Fix improper multi-dimensionality in SqrtORAM --- Compiler/sqrt_oram.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index c4b14f99b..5a5e24ddc 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -82,7 +82,17 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. period (int): Leave at None, this parameter is used to recursively pass down the top-level period. """ - self.n = len(data) + if isinstance(data, MultiArray): + self.shuffle = data + self.n = len(data) + elif isinstance(data, sint): + self.n = math.ceil(len(data) // entry_length) + if (len(data) % entry_length != 0): + raise Exception('Data incorrectly padded.') + self.shuffle = MultiArray((self.n, entry_length), value_type=value_type) + self.shuffle.assign_part_vector(data.get_vector()) + else: + raise Exception("Incorrect format.") self.value_type = value_type if value_type != sint and value_type != sbitint: @@ -95,13 +105,6 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) self.shuffle_used = cint.Array(self.n) # Random permutation on the data - if isinstance(data, MultiArray): - self.shuffle = data - elif isinstance(data, sint): - self.shuffle = MultiArray((self.n, self.entry_length), value_type=value_type) - self.shuffle.assign_vector(data.get_vector()) - else: - raise Exception("Incorrect format.") self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given @@ -124,8 +127,8 @@ def read(self, index: T): return self.access(index, self.bit_type(False), *value) @lib.method_block - def write(self, index: T, value: T): - lib.runtime_error_if(value.size != self.entry_length, "A block must be of size entry_length") + def write(self, index: T, *value: T): + lib.runtime_error_if(len(value) != self.entry_length, "A block must be of size entry_length") self.access(index, self.bit_type(True), *value) __getitem__ = read From 33299e78a58553c8d61706a40fcb257161db0323 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 17:35:48 +0200 Subject: [PATCH 118/221] Add multithreading to LinearPositionMap in SqrtORAM --- Compiler/sqrt_oram.py | 55 +++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5a5e24ddc..5de1174f8 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -22,9 +22,7 @@ program = Program.prog debug = False -n_parallel = 1024 n_threads = 8 - multithreading = True def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): @@ -486,26 +484,47 @@ def get_position(self, logical_address: T, fake: B) -> _clear: This method corresponds to GetPosBase in the paper. """ super().get_position(logical_address, fake) - fake = self.bit_type(fake) - - # In order to get an address at secret logical_address, - # we need to perform a linear scan. - linear_scan = self.bit_type.Array(self.n) - @lib.for_range_opt(self.n) - def _(i): - linear_scan[i] = logical_address == i + fake = MemValue(self.bit_type(fake)) + logical_address = MemValue(logical_address) p: MemValue = MemValue(self.index_type(-1)) done: B = self.bit_type(False) - @lib.for_range_opt(self.n) - def _(j): - nonlocal done, fake - condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ - .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) - p.write(condition.if_else(self.physical[j], p)) - self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) - done = self.bit_type(condition.if_else(self.bit_type(True), done)) + if multithreading: + conditions:Array = self.bit_type.Array(self.n) + conditions.assign_all(0) + + @lib.for_range_opt_multithread(8, self.n) + def condition_i(i): + conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i) + + @lib.for_range_opt(self.n) + def _(i): + nonlocal done + conditions[i] &= done.bit_not() + done |= conditions[i] + @lib.map_sum_opt(8, self.n, [self.value_type]) + def calc_p(i): + return self.physical[i] * conditions[i] + p.write(calc_p()) + + self.used.assign(self.used[:] | conditions[:]) + else: + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + linear_scan = self.bit_type.Array(self.n) + @lib.for_range_opt(self.n) + def _(i): + linear_scan[i] = logical_address == i + + @lib.for_range_opt(self.n) + def __(j): + nonlocal done, fake + condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ + .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) + p.write(condition.if_else(self.physical[j], p)) + self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) + done = self.bit_type(condition.if_else(self.bit_type(True), done)) if debug: @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) From 2cd263dad0344355ff4837d04c26fbf93bd4abc7 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Mon, 1 Aug 2022 17:58:45 +0200 Subject: [PATCH 119/221] Improve multithreading and remove non-multithreaded code --- Compiler/sqrt_oram.py | 477 +++++++++++++++++++++++++++++------------- 1 file changed, 332 insertions(+), 145 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5de1174f8..a757c045e 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -3,10 +3,10 @@ import math from typing import Any, Generic, Type, TypeVar -from Compiler.program import Program from Compiler import util from Compiler import library as lib from Compiler.GC.types import cbit, sbit, sbitint, sbits +from Compiler.program import Program from Compiler.types import ( Array, MemValue, @@ -14,16 +14,28 @@ _clear, _secret, cint, + regint, sint, sintbit, - regint ) +from oram import get_n_threads program = Program.prog -debug = False +debug = True +trace = True n_threads = 8 -multithreading = True +n_parallel = 1 + +def get_n_threads(n_loops): + if n_threads is None: + if n_loops > 2048: + return 8 + else: + return None + else: + return n_threads + def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): """Swap two positions in an Array if a condition is met. @@ -43,9 +55,11 @@ def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: array[pos_b] = cond.if_else(array[pos_a], array[pos_b]) array[pos_a] = cond.if_else(temp, array[pos_a]) + T = TypeVar("T", sint, sbitint) B = TypeVar("B", sintbit, sbit) + class SqrtOram(Generic[T, B]): # TODO: Preferably this is an Array of vectors, but this is currently not supported # One should regard these structures as Arrays where an entry may hold more @@ -75,7 +89,7 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type """Initialize a new Oblivious RAM using the "Square-Root" algorithm. Args: - data (MultiArray): The data with which to initialize the ORAM. For all intents and purposes, data is regarded as a one-dimensional Array. However, one may provide a MultiArray such that every "block" can hold multiple elements (an Array). + data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). value_type (sint): The secret type to use, defaults to sint. k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. period (int): Leave at None, this parameter is used to recursively pass down the top-level period. @@ -87,7 +101,8 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type self.n = math.ceil(len(data) // entry_length) if (len(data) % entry_length != 0): raise Exception('Data incorrectly padded.') - self.shuffle = MultiArray((self.n, entry_length), value_type=value_type) + self.shuffle = MultiArray( + (self.n, entry_length), value_type=value_type) self.shuffle.assign_part_vector(data.get_vector()) else: raise Exception("Incorrect format.") @@ -96,19 +111,23 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type if value_type != sint and value_type != sbitint: raise Exception("The value_type must be either sint or sbitint") self.bit_type: Type[B] = value_type.bit_type - self.index_type = value_type.get_type(util.log2(self.n)) + self.index_size = util.log2(self.n) + self.index_type = value_type.get_type(self.index_size) self.entry_length = entry_length if debug: - lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) + lib.print_ln( + 'Initializing SqrtORAM of size %s at depth %s', self.n, k) self.shuffle_used = cint.Array(self.n) # Random permutation on the data - self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) + self.shufflei = Array.create_from( + [self.index_type(i) for i in range(self.n)]) permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter - self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period + self.T = int(math.ceil( + math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period if debug and not period: lib.print_ln('Period set to %s', self.T) # Initialize position map (recursive oram) @@ -119,29 +138,21 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type self.stashi = Array(self.T, value_type=value_type) self.t = MemValue(cint(0)) - @lib.method_block - def read(self, index: T): - value = self.value_type(0, size=self.entry_length) - return self.access(index, self.bit_type(False), *value) - - @lib.method_block - def write(self, index: T, *value: T): - lib.runtime_error_if(len(value) != self.entry_length, "A block must be of size entry_length") - self.access(index, self.bit_type(True), *value) - - __getitem__ = read - __setitem__ = write + # Initialize temp variables needed during the computation + self.found_ = self.bit_type.Array(size=self.T) @lib.method_block def access(self, index: T, write: B, *value: T): - if debug: + if trace: @lib.if_e(write.reveal() == 1) def _(): lib.print_ln('Writing to secret index %s', index.reveal()) + @lib.else_ def __(): lib.print_ln('Reading from secret index %s', index.reveal()) - value = self.value_type(value) + + value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length) index = MemValue(index) # Refresh if we have performed T (period) accesses @@ -150,87 +161,247 @@ def _(): self.refresh() found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in a bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ + self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + # Store the stash item into the result if found + # If the item is not in the stash, the result will simple remain 0 + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + # If we are writing, we need to add the value + self.stash[i] += write * access_here * (value - entry) + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('\tFound item in stash') + + @lib.else_ + def __(): + lib.print_ln('\tDid not find item in stash') - # Scan through the stash - @lib.if_(self.t > 0) + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # ...we update the item in the shuffle + self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:]) + # ...and the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + @lib.if_((write * found.bit_not()).reveal()) + def _(): + lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) + + lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + @lib.method_block + def write(self, index: T, *value: T): + if trace: + lib.print_ln('Writing to secret index %s', index.reveal()) + + value = self.value_type(value) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) def _(): - nonlocal found - found |= index == self.stashi[0] - # We ensure that if the item is found in stash, it ends up in the first - # position (more importantly, a fixed position) of the stash - # This allows us to keep track of it in an oblivious manner - if multithreading: - found_ = self.bit_type.Array(size=self.T) - @lib.multithread(1, self.T) - def _(base, size): - found_.assign_vector((self.stashi.get_vector(base, size) == index.expand_to_vector(size)) - & self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), - base=base) - @lib.for_range_opt(self.t - 1) - def _(i): - swap(self.stash, 0, i + 1, found_[i+1]) - swap(self.stashi, 0, i + 1, found_[i+1]) - found.write(sum(found_)) - else: - @lib.for_range_opt(self.t - 1) - def _(i): - nonlocal found - found_: B = index == self.stashi[i + 1] - swap(self.stash, 0, i + 1, found_) - swap(self.stashi, 0, i + 1, found_) - found |= found_ - # If the item was not in the stash, we move the unknown and unimportant - # stash[0] out of the way (to the end of the stash) - swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) - swap(self.stashi, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) + self.refresh() - if debug: + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in an bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ + self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + # We update the stash value + self.stash[i] += access_here * (value - entry) + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: @lib.if_e(found.reveal() == 1) def _(): lib.print_ln('\tFound item in stash') + @lib.else_ def __(): - lib.print_ln('\tItem not in stash') - lib.print_ln('\tMoved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + lib.print_ln('\tDid not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. self.shuffle_used[physical_address] = cbit(True) - # If the item was in the stash (thus currently residing in stash[0]), - # we place the random item retrieved from the shuffle at the end of the stash - self.stash[self.t].assign(found.if_else( - self.shuffle[physical_address][:], - self.stash[self.t][:])) - self.stashi[self.t] = found.if_else( - self.shufflei[physical_address], - self.stashi[self.t]) - # If the item was not found in the stash, we place the item retrieved - # from the shuffle (the item we are actually looking for) in stash[0] - self.stash[0].assign(found.bit_not().if_else( - self.shuffle[physical_address][:], - self.stash[0][:])) - self.stashi[0] = found.bit_not().if_else( - self.shufflei[physical_address], - self.stashi[0]) + # If the item was not found in the stash + # ...we update the item in the shuffle + self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:]) + # ...and the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + @lib.if_(found.bit_not().reveal()) + def _(): + lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) - if debug: + lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + @lib.method_block + def read(self, index: T, *value: T): + if trace: + lib.print_ln('Reading from secret index %s', index.reveal()) + value = self.value_type(value) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in a bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ + self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + # Store the stash item into the result if found + # If the item is not in the stash, the result will simple remain 0 + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[t]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + lib.print_ln('\tFound item in stash') + @lib.else_ def __(): - lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[0]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + lib.print_ln('\tDid not find item in stash') + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + if trace: + lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) # Increase the "time" (i.e. access count in current period) self.t.iadd(1) - self.stash[0].assign(write.if_else(value, self.stash[0][:])) - value=write.bit_not().if_else(self.stash[0][:], value) - return value + return result + __getitem__ = read + __setitem__ = write @lib.method_block def shuffle_the_shuffle(self): @@ -242,12 +413,15 @@ def shuffle_the_shuffle(self): # Random permutation on n elements random_shuffle = sint.get_secure_shuffle(self.n) - if debug: lib.print_ln('\tGenerated shuffle') + if trace: + lib.print_ln('\tGenerated shuffle') # Apply the random permutation self.shuffle.secure_permute(random_shuffle) - if debug: lib.print_ln('\tShuffled shuffle') + if trace: + lib.print_ln('\tShuffled shuffle') self.shufflei.secure_permute(random_shuffle) - if debug: lib.print_ln('\tShuffled shuffle indexes') + if trace: + lib.print_ln('\tShuffled shuffle indexes') # Calculate the permutation that would have produced the newly produced # shuffle order. This can be calculated by regarding the logical # indexes (shufflei) as a permutation and calculating its inverse, @@ -256,7 +430,8 @@ def shuffle_the_shuffle(self): # random_shuffle, as the shuffle may already be out of order (e.g. when # refreshing). permutation = MemValue(self.shufflei[:].inverse_permutation()) - if debug: lib.print_ln('\tCalculated inverse permutation') + if trace: + lib.print_ln('\tCalculated inverse permutation') return permutation @lib.method_block @@ -266,10 +441,12 @@ def refresh(self): This must happen after T (period) accesses to the ORAM.""" - if debug: lib.print_ln('Refreshing SqrtORAM') + if trace: + lib.print_ln('Refreshing SqrtORAM') # Shuffle and emtpy the stash, and store elements back into shuffle - j = MemValue(cint(0,size=1)) + j = MemValue(cint(0, size=1)) + @lib.for_range_opt(self.n) def _(i): @lib.if_(self.shuffle_used[i]) @@ -301,13 +478,14 @@ def reinitialize(self, *data: T): self.shufflei.assign([self.index_type(i) for i in range(self.n)]) # Reset the clock self.t.write(0) - # Reset shuffle_used + # Reset shuffle_used self.shuffle_used.assign_all(0) # Note that the self.shuffle is actually a MultiArray # This structure is preserved while overwriting the values using # assign_vector - self.shuffle.assign_vector(self.value_type(data, size=self.n * self.entry_length)) + self.shuffle.assign_vector(self.value_type( + data, size=self.n * self.entry_length)) permutation = self.shuffle_the_shuffle() self.position_map.reinitialize(*permutation) @@ -316,13 +494,13 @@ class PositionMap(Generic[T, B]): PACK_LOG: int = 3 PACK: int = 1 << PACK_LOG - n: int # n in the paper - depth: int # k in the paper + n: int # n in the paper + depth: int # k in the paper value_type: Type[T] - def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None: + def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None: self.n = n - self.depth=MemValue(cint(k)) + self.depth = MemValue(cint(k)) self.value_type = value_type self.bit_type = value_type.bit_type self.index_type = self.value_type.get_type(util.log2(n)) @@ -330,8 +508,9 @@ def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None: @abstractmethod def get_position(self, logical_address: _secret, fake: B) -> Any: """Retrieve the block at the given (secret) logical address.""" - if debug: - lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) + if trace: + lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, + self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) def reinitialize(self, *permutation: T): """Reinitialize this PositionMap. @@ -352,11 +531,13 @@ def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = s if n / PositionMap.PACK <= period: if debug: - lib.print_ln('Initializing LinearPositionMap at depth %s of size %s', k, n) + lib.print_ln( + 'Initializing LinearPositionMap at depth %s of size %s', k, n) res = LinearPositionMap(permutation, value_type, k=k) else: if debug: - lib.print_ln('Initializing RecursivePositionMap at depth %s of size %s', k, n) + lib.print_ln( + 'Initializing RecursivePositionMap at depth %s of size %s', k, n) res = RecursivePositionMap(permutation, period, value_type, k=k) return res @@ -364,7 +545,7 @@ def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = s class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): - def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k:int=-1) -> None: + def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None: PositionMap.__init__(self, len(permutation), k=k) pack = PositionMap.PACK @@ -377,7 +558,8 @@ def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, permutation[i*pack:(i+1)*pack]) # TODO: Should this be n or packed_size? - SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth) + SqrtOram.__init__(self, packed_structure, value_type=value_type, + period=period, entry_length=pack, k=self.depth) @lib.method_block def get_position(self, logical_address: T, fake: B) -> _clear: @@ -389,7 +571,8 @@ def get_position(self, logical_address: T, fake: B) -> _clear: # The item at logical_address # will be in block with index h (block.) # at position l in block.data (block.data) - h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(logical_address).right_shift(pack_log, program.bit_length))) + h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)( + logical_address).right_shift(pack_log, program.bit_length))) l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) # The resulting physical address @@ -401,32 +584,37 @@ def get_position(self, logical_address: T, fake: B) -> _clear: # First we scan the stash for the block we need condition1 = self.bit_type.Array(self.T) + @lib.for_range_opt_multithread(8, self.T) def _(i): condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t) found = sum(condition1) # Once a block is found, we use condition2 to pick the correct item from that block - condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack)) + condition2 = Array.create_from( + regint.inc(pack) == l.expand_to_vector(pack)) # condition3 combines condition1 & condition2, only returning true at stash[h][l] condition3 = self.bit_type.Array(self.T * pack) + @lib.for_range_opt_multithread(8, [self.T, pack]) def _(i, j): condition3[i*pack + j] = condition1[i] & condition2[j] # Finally we use condition3 to conditionally write p + @lib.for_range(self.t) def _(i): @lib.for_range(pack) def _(j): p.write(condition3[i*pack + j].if_else(self.stash[i][j], p)) - if debug: + if trace: @lib.if_(condition1[i].reveal() == 1) def _(): - lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal()) + lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal( + ), self.stash[i].reveal()) # Then we try and retrieve the item from the shuffle (the actual memory) - if debug: + if trace: @lib.if_(found.reveal() == 0) def _(): lib.print_ln('\t%s Position not in stash', self.depth) @@ -438,15 +626,16 @@ def _(): # The block retrieved from the shuffle block_p_prime: Array = self.shuffle[p_prime] - if debug: + if trace: @lib.if_e(found.reveal() == 0) def _(): lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', - self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + @lib.else_ def __(): lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)', - self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) # We add the retrieved block from the shuffle to the stash self.stash[self.t].assign(block_p_prime[:]) @@ -458,7 +647,9 @@ def __(): condition: B = self.bit_type(fake.bit_or(found.bit_not())) # Retrieve l'th item from block # l is secret, so we must use linear scan - hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack)) + hit = Array.create_from((regint.inc(pack) == l.expand_to_vector( + pack)) & condition.expand_to_vector(pack)) + @lib.for_range_opt(pack) def _(i): p.write((hit[i]).if_else(block_p_prime[i], p)) @@ -469,67 +660,63 @@ def _(i): def reinitialize(self, *permutation: T): SqrtOram.reinitialize(self, *permutation) + class LinearPositionMap(PositionMap): physical: Array used: Array - def __init__(self, data: Array, value_type: Type[T] = sint, k:int =-1) -> None: + def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None: PositionMap.__init__(self, len(data), value_type, k=k) self.physical = data self.used = self.bit_type.Array(self.n) + # Initialize random temp variables needed during the computation + self.physical_demux: Array = self.bit_type.Array(self.n) + @lib.method_block def get_position(self, logical_address: T, fake: B) -> _clear: """ This method corresponds to GetPosBase in the paper. """ super().get_position(logical_address, fake) + fake = MemValue(self.bit_type(fake)) logical_address = MemValue(logical_address) p: MemValue = MemValue(self.index_type(-1)) done: B = self.bit_type(False) - if multithreading: - conditions:Array = self.bit_type.Array(self.n) - conditions.assign_all(0) - - @lib.for_range_opt_multithread(8, self.n) - def condition_i(i): - conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i) - - @lib.for_range_opt(self.n) - def _(i): - nonlocal done - conditions[i] &= done.bit_not() - done |= conditions[i] - @lib.map_sum_opt(8, self.n, [self.value_type]) - def calc_p(i): - return self.physical[i] * conditions[i] - p.write(calc_p()) - - self.used.assign(self.used[:] | conditions[:]) - else: - # In order to get an address at secret logical_address, - # we need to perform a linear scan. - linear_scan = self.bit_type.Array(self.n) - @lib.for_range_opt(self.n) - def _(i): - linear_scan[i] = logical_address == i - - @lib.for_range_opt(self.n) - def __(j): - nonlocal done, fake - condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ - .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) - p.write(condition.if_else(self.physical[j], p)) - self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) - done = self.bit_type(condition.if_else(self.bit_type(True), done)) + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + self.physical_demux.assign_all(0) + @lib.for_range_opt_multithread(8, self.n) + def condition_i(i): + self.physical_demux.assign((self.bit_type(fake).bit_not() + & self.bit_type(logical_address == i)) | (fake + & self.used[i].bit_not()), base=i) + + # In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False) + # We only need once, so we pick the first one we find + @lib.for_range_opt(self.n) + def _(i): + nonlocal done + self.physical_demux[i] &= done.bit_not() + done |= self.physical_demux[i] - if debug: + # Retrieve the value from the physical memory obliviously + @lib.map_sum_opt(8, self.n, [self.value_type]) + def calc_p(i): + return self.physical[i] * self.physical_demux[i] + p.write(calc_p()) + + # Update self.used + self.used.assign(self.used[:] | self.physical_demux[:]) + + if trace: @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) def _(): - lib.runtime_error('%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) + lib.runtime_error( + '%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) return p.reveal() From fb5871a2f8a54bae8af1ea64114496b962781e82 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 11 Aug 2022 14:41:20 +0200 Subject: [PATCH 120/221] Add allow_memory_allocation option to SqrtORAM Also remove unused swap function in SqrtORAM --- Compiler/sqrt_oram.py | 311 ++++++++++++++++++++++++------------------ 1 file changed, 179 insertions(+), 132 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index a757c045e..d732e2f26 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -1,32 +1,34 @@ from __future__ import annotations -from abc import abstractmethod + import math +from abc import abstractmethod from typing import Any, Generic, Type, TypeVar -from Compiler import util from Compiler import library as lib +from Compiler import util from Compiler.GC.types import cbit, sbit, sbitint, sbits from Compiler.program import Program -from Compiler.types import ( - Array, - MemValue, - MultiArray, - _clear, - _secret, - cint, - regint, - sint, - sintbit, -) -from oram import get_n_threads +from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint, + regint, sint, sintbit) +from oram import demux_array, get_n_threads program = Program.prog -debug = True -trace = True -n_threads = 8 +# Adds messages on completion of heavy computation steps +debug = False +# Finer grained trace of steps that the ORAM performs +# + runtime error checks +# Warning: reveals information and makes the computation insecure +trace = False + +n_threads = 16 n_parallel = 1 +# Avoids any memory allocation +# This prevents some optimizations but allows for using the ORAMs outside of the main tape +allow_memory_allocation = True + + def get_n_threads(n_loops): if n_threads is None: if n_loops > 2048: @@ -37,25 +39,6 @@ def get_n_threads(n_loops): return n_threads -def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): - """Swap two positions in an Array if a condition is met. - - Args: - array (Array | MultiArray): The array in which to swap the first and second position - pos_a (int | cint): The first position - pos_b (int | cint): The second position - cond (sintbit | sbit): The condition determining whether to swap - """ - if isinstance(array, MultiArray): - temp = array[pos_b][:] - array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:])) - array[pos_a].assign(cond.if_else(temp, array[pos_a][:])) - if isinstance(array, Array): - temp = array[pos_b] - array[pos_b] = cond.if_else(array[pos_a], array[pos_b]) - array[pos_a] = cond.if_else(temp, array[pos_a]) - - T = TypeVar("T", sint, sbitint) B = TypeVar("B", sintbit, sbit) @@ -85,7 +68,7 @@ class SqrtOram(Generic[T, B]): # the stash) t: cint - def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: + def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True) -> None: """Initialize a new Oblivious RAM using the "Square-Root" algorithm. Args: @@ -94,6 +77,9 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. period (int): Leave at None, this parameter is used to recursively pass down the top-level period. """ + global debug, allow_memory_allocation + + # Correctly initialize the shuffle (memory) depending on the type of data if isinstance(data, MultiArray): self.shuffle = data self.n = len(data) @@ -107,9 +93,12 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type else: raise Exception("Incorrect format.") - self.value_type = value_type + # Only sint is supported if value_type != sint and value_type != sbitint: raise Exception("The value_type must be either sint or sbitint") + + # Set derived constants + self.value_type = value_type self.bit_type: Type[B] = value_type.bit_type self.index_size = util.log2(self.n) self.index_type = value_type.get_type(self.index_size) @@ -118,11 +107,11 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type if debug: lib.print_ln( 'Initializing SqrtORAM of size %s at depth %s', self.n, k) + self.shuffle_used = cint.Array(self.n) # Random permutation on the data self.shufflei = Array.create_from( [self.index_type(i) for i in range(self.n)]) - permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter @@ -130,8 +119,21 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period if debug and not period: lib.print_ln('Period set to %s', self.T) + + # Here we allocate the memory for the permutation + # Note that self.shuffle_the_shuffle mutates this field + # Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading + self.permutation = Array.create_from( + [self.index_type(i) for i in range(self.n)]) + # We allow the caller to postpone the initialization of the shuffle + # This is the most expensive operation, and can be done in a thread (only if you know what you're doing) + # Note that if you do not initialize, the ORAM is insecure + if initialize: + self.shuffle_the_shuffle() + else: + print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.') # Initialize position map (recursive oram) - self.position_map = PositionMap.create(permutation, k + 1, self.T) + self.position_map = PositionMap.create(self.permutation, k + 1, self.T) # Initialize stash self.stash = MultiArray((self.T, entry_length), value_type=value_type) @@ -140,19 +142,28 @@ def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type # Initialize temp variables needed during the computation self.found_ = self.bit_type.Array(size=self.T) + self.j = MemValue(cint(0, size=1)) + + # To prevent the compiler from recompiling the same code over and over again, we should use @method_block + # However, @method_block requires allocation (of return address), which is not allowed when not in the main thread + # Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread + self.shuffle_the_shuffle = lib.method_block(self.shuffle_the_shuffle) if allow_memory_allocation else self.shuffle_the_shuffle + self.refresh = lib.method_block(self.refresh) if allow_memory_allocation else self.refresh @lib.method_block def access(self, index: T, write: B, *value: T): + global trace,n_parallel if trace: @lib.if_e(write.reveal() == 1) def _(): - lib.print_ln('Writing to secret index %s', index.reveal()) + lib.print_ln(' Writing to secret index %s', index.reveal()) @lib.else_ def __(): - lib.print_ln('Reading from secret index %s', index.reveal()) + lib.print_ln(' Reading from secret index %s', index.reveal()) - value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length) + value = self.value_type(value, size=self.entry_length).get_vector( + 0, size=self.entry_length) index = MemValue(index) # Refresh if we have performed T (period) accesses @@ -171,8 +182,9 @@ def _(): @lib.multithread(get_n_threads(self.T), self.T) def _(base, size): self.found_.assign_vector( - (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ - self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), base=base) # To determine whether the item is found in the stash, we simply @@ -200,11 +212,11 @@ def stash_item(i): if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tFound item in stash') + lib.print_ln(' Found item in stash') @lib.else_ def __(): - lib.print_ln('\tDid not find item in stash') + lib.print_ln(' Did not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -215,7 +227,8 @@ def __(): # If the item was not found in the stash # ...we update the item in the shuffle - self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:]) + self.shuffle[physical_address] += write * \ + found.bit_not() * (value - self.shuffle[physical_address][:]) # ...and the item retrieved from the shuffle is our result result += self.shuffle[physical_address] * found.bit_not() # We append the newly retrieved item to the stash @@ -225,10 +238,8 @@ def __(): if trace: @lib.if_((write * found.bit_not()).reveal()) def _(): - lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) - - lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, - self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal( + ), self.shuffle[physical_address].reveal(), physical_address) # Increase the "time" (i.e. access count in current period) self.t.iadd(1) @@ -237,8 +248,9 @@ def _(): @lib.method_block def write(self, index: T, *value: T): + global trace, n_parallel if trace: - lib.print_ln('Writing to secret index %s', index.reveal()) + lib.print_ln(' Writing to secret index %s', index.reveal()) value = self.value_type(value) index = MemValue(index) @@ -259,8 +271,9 @@ def _(): @lib.multithread(get_n_threads(self.T), self.T) def _(base, size): self.found_.assign_vector( - (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ - self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), base=base) # To determine whether the item is found in the stash, we simply @@ -286,11 +299,11 @@ def stash_item(i): if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tFound item in stash') + lib.print_ln(' Found item in stash') @lib.else_ def __(): - lib.print_ln('\tDid not find item in stash') + lib.print_ln(' Did not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -301,7 +314,8 @@ def __(): # If the item was not found in the stash # ...we update the item in the shuffle - self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:]) + self.shuffle[physical_address] += found.bit_not() * \ + (value - self.shuffle[physical_address][:]) # ...and the item retrieved from the shuffle is our result result += self.shuffle[physical_address] * found.bit_not() # We append the newly retrieved item to the stash @@ -311,9 +325,10 @@ def __(): if trace: @lib.if_(found.bit_not().reveal()) def _(): - lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) + lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal( + ), self.shuffle[physical_address].reveal(), physical_address) - lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) # Increase the "time" (i.e. access count in current period) @@ -323,14 +338,20 @@ def _(): @lib.method_block def read(self, index: T, *value: T): + global debug, trace, n_parallel if trace: - lib.print_ln('Reading from secret index %s', index.reveal()) + lib.print_ln(' Reading from secret index %s', index.reveal()) + value = self.value_type(value) index = MemValue(index) # Refresh if we have performed T (period) accesses @lib.if_(self.t == self.T) def _(): + if debug: + lib.print_ln(' Refreshing SqrtORAM') + lib.print_ln(' t=%s according to me', self.t) + self.refresh() found: B = MemValue(self.bit_type(False)) @@ -344,8 +365,9 @@ def _(): @lib.multithread(get_n_threads(self.T), self.T) def _(base, size): self.found_.assign_vector( - (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ - self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), base=base) # To determine whether the item is found in the stash, we simply @@ -371,11 +393,11 @@ def stash_item(i): if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tFound item in stash') + lib.print_ln(' Found item in stash') @lib.else_ def __(): - lib.print_ln('\tDid not find item in stash') + lib.print_ln(' Did not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -392,7 +414,7 @@ def __(): self.stashi[self.t] = self.shufflei[physical_address] if trace: - lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) # Increase the "time" (i.e. access count in current period) @@ -403,25 +425,36 @@ def __(): __getitem__ = read __setitem__ = write - @lib.method_block - def shuffle_the_shuffle(self): + def shuffle_the_shuffle(self) -> None: """Permute the memory using a newly generated permutation and return the permutation that would generate this particular shuffling. This permutation is needed to know how to map logical addresses to physical addresses, and is used as such by the postition map.""" + global trace # Random permutation on n elements random_shuffle = sint.get_secure_shuffle(self.n) if trace: - lib.print_ln('\tGenerated shuffle') + lib.print_ln(' Generated shuffle') # Apply the random permutation self.shuffle.secure_permute(random_shuffle) if trace: - lib.print_ln('\tShuffled shuffle') + lib.print_ln(' Shuffled shuffle') self.shufflei.secure_permute(random_shuffle) if trace: - lib.print_ln('\tShuffled shuffle indexes') + lib.print_ln(' Shuffled shuffle indexes') + + if trace: + # If shufflei does not contain exactly the indices [i for i in + # range(self.n)], the underlying waksman network of + # 'inverse_permutation' will hang. + tmp_shuffli = Array.create_from(self.shufflei[:]) + @lib.if_(sum(lib.sort(tmp_shuffli)[:] == Array.create_from([cint(i) for i in range(self.n)])[:]).reveal() != self.n) + def _(): + lib.print_ln( + 'Shufflei is corrupted! You have found a bug in the implementation :c\nThe computation will now hang...') + # Calculate the permutation that would have produced the newly produced # shuffle order. This can be calculated by regarding the logical # indexes (shufflei) as a permutation and calculating its inverse, @@ -429,45 +462,45 @@ def shuffle_the_shuffle(self): # this is not necessarily equal to the inverse of the above generated # random_shuffle, as the shuffle may already be out of order (e.g. when # refreshing). - permutation = MemValue(self.shufflei[:].inverse_permutation()) + self.permutation.assign(self.shufflei[:].inverse_permutation()) if trace: - lib.print_ln('\tCalculated inverse permutation') - return permutation + lib.print_ln(' Calculated inverse permutation') - @lib.method_block def refresh(self): """Refresh the ORAM by reinserting the stash back into the shuffle, and reshuffling the shuffle. - This must happen after T (period) accesses to the ORAM.""" - - if trace: - lib.print_ln('Refreshing SqrtORAM') + This must happen on the T'th (period) accesses to the ORAM.""" + self.j.write(0) # Shuffle and emtpy the stash, and store elements back into shuffle - j = MemValue(cint(0, size=1)) @lib.for_range_opt(self.n) def _(i): @lib.if_(self.shuffle_used[i]) def _(): - nonlocal j - self.shuffle[i] = self.stash[j] - self.shufflei[i] = self.stashi[j] - j += 1 + self.shuffle[i] = self.stash[self.j] + self.shufflei[i] = self.stashi[self.j] + self.j += 1 # Reset the clock self.t.write(0) # Reset shuffle_used - self.shuffle_used.assign_all(0) + global allow_memory_allocation + if allow_memory_allocation: + self.shuffle_used.assign_all(0) + else: + @lib.for_range_opt(self.n) + def _(i): + self.shuffle_used[i] = cint(0) # Reinitialize position map - permutation = self.shuffle_the_shuffle() + self.shuffle_the_shuffle() # Note that we skip here the step of "packing" the permutation. # Since the underlying memory of the position map is already aligned in # this packed structure, we can simply overwrite the memory while # maintaining the structure. - self.position_map.reinitialize(*permutation) + self.position_map.reinitialize(*self.permutation) @lib.method_block def reinitialize(self, *data: T): @@ -478,7 +511,7 @@ def reinitialize(self, *data: T): self.shufflei.assign([self.index_type(i) for i in range(self.n)]) # Reset the clock self.t.write(0) - # Reset shuffle_used + # Reset shuffle_used self.shuffle_used.assign_all(0) # Note that the self.shuffle is actually a MultiArray @@ -486,8 +519,10 @@ def reinitialize(self, *data: T): # assign_vector self.shuffle.assign_vector(self.value_type( data, size=self.n * self.entry_length)) - permutation = self.shuffle_the_shuffle() - self.position_map.reinitialize(*permutation) + # Note that this updates self.permutation (see constructor for explanation) + self.shuffle_the_shuffle() + self.position_map.reinitialize(*self.permutation) + class PositionMap(Generic[T, B]): @@ -508,8 +543,9 @@ def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None: @abstractmethod def get_position(self, logical_address: _secret, fake: B) -> Any: """Retrieve the block at the given (secret) logical address.""" + global trace if trace: - lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, + lib.print_ln(' %s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) def reinitialize(self, *permutation: T): @@ -529,6 +565,7 @@ def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = s a LinearPositionMap.""" n = len(permutation) + global debug if n / PositionMap.PACK <= period: if debug: lib.print_ln( @@ -561,6 +598,10 @@ def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth) + # Initialize random temp variables needed during the computation + self.block_index_demux: Array = self.bit_type.Array(self.T) + self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK) + @lib.method_block def get_position(self, logical_address: T, fake: B) -> _clear: super().get_position(logical_address, fake) @@ -576,50 +617,42 @@ def get_position(self, logical_address: T, fake: B) -> _clear: l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) # The resulting physical address - p = MemValue(self.index_type(0)) + p = MemValue(self.index_type(-1)) found: B = MemValue(self.bit_type(False)) # First we try and retrieve the item from the stash at position stash[h][l] # Since h and l are secret, we do this by scanning the entire stash # First we scan the stash for the block we need - condition1 = self.bit_type.Array(self.T) + self.block_index_demux.assign_all(0) - @lib.for_range_opt_multithread(8, self.T) - def _(i): - condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t) - found = sum(condition1) - # Once a block is found, we use condition2 to pick the correct item from that block - condition2 = Array.create_from( - regint.inc(pack) == l.expand_to_vector(pack)) - # condition3 combines condition1 & condition2, only returning true at stash[h][l] - condition3 = self.bit_type.Array(self.T * pack) - - @lib.for_range_opt_multithread(8, [self.T, pack]) - def _(i, j): - condition3[i*pack + j] = condition1[i] & condition2[j] - # Finally we use condition3 to conditionally write p - - @lib.for_range(self.t) + @lib.for_range_opt_multithread(get_n_threads(self.T), self.T) def _(i): - @lib.for_range(pack) - def _(j): - p.write(condition3[i*pack + j].if_else(self.stash[i][j], p)) - - if trace: - @lib.if_(condition1[i].reveal() == 1) - def _(): - lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal( - ), self.stash[i].reveal()) - - # Then we try and retrieve the item from the shuffle (the actual memory) - + self.block_index_demux[i] = ( + self.stashi[i] == h) & self.bit_type(i < self.t) + # We can determine if the 'index' is in the stash by checking the + # block_index_demux array + found = sum(self.block_index_demux) + # Once a block is found, we use the following condition to pick the correct item from that block + demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux) + + # Finally we use the conditions to conditionally write p + @lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type]) + def p_(i): + # We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum + # Therefore we include the check (i < self.t) + return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i < self.t) + p.write(p_()) + + global trace if trace: @lib.if_(found.reveal() == 0) def _(): - lib.print_ln('\t%s Position not in stash', self.depth) + lib.print_ln(' %s Position not in stash', self.depth) - # Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle + # Then we try and retrieve the item from the shuffle (the actual memory) + # Depending on whether we found the item in the stash, we either + # block 'h' in which 'index' resides, or a random block from the shuffle p_prime = self.position_map.get_position(h, found) self.shuffle_used[p_prime] = cbit(True) @@ -629,12 +662,12 @@ def _(): if trace: @lib.if_e(found.reveal() == 0) def _(): - lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', + lib.print_ln(' %s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) @lib.else_ def __(): - lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)', + lib.print_ln(' %s Retrieved dummy stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) # We add the retrieved block from the shuffle to the stash @@ -680,6 +713,13 @@ def get_position(self, logical_address: T, fake: B) -> _clear: """ super().get_position(logical_address, fake) + global trace + if trace: + @lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal()) + def _(): + lib.runtime_error( + 'logical_address must lie between 0 and self.n - 1') + fake = MemValue(self.bit_type(fake)) logical_address = MemValue(logical_address) @@ -689,11 +729,12 @@ def get_position(self, logical_address: T, fake: B) -> _clear: # In order to get an address at secret logical_address, # we need to perform a linear scan. self.physical_demux.assign_all(0) - @lib.for_range_opt_multithread(8, self.n) + + @lib.for_range_opt_multithread(get_n_threads(self.n), self.n) def condition_i(i): - self.physical_demux.assign((self.bit_type(fake).bit_not() - & self.bit_type(logical_address == i)) | (fake - & self.used[i].bit_not()), base=i) + self.physical_demux[i] = \ + (self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \ + | (fake & self.used[i].bit_not()) # In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False) # We only need once, so we pick the first one we find @@ -704,7 +745,7 @@ def _(i): done |= self.physical_demux[i] # Retrieve the value from the physical memory obliviously - @lib.map_sum_opt(8, self.n, [self.value_type]) + @lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type]) def calc_p(i): return self.physical[i] * self.physical_demux[i] p.write(calc_p()) @@ -720,7 +761,13 @@ def _(): return p.reveal() - @lib.method_block def reinitialize(self, *data: T): self.physical.assign_vector(data) - self.used.assign_all(False) + + global allow_memory_allocation + if allow_memory_allocation: + self.used.assign_all(False) + else: + @lib.for_range_opt(self.n) + def _(i): + self.used[i] = self.bit_type(0) From a4e4baddeedefbf2a6328caf72e738a03f3930ef Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 12 Aug 2022 10:45:36 -0700 Subject: [PATCH 121/221] add new in python compile to docs --- README.md | 31 ++++++++++++++++++++++++++++++- doc/index.rst | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 18278c25f..79bf5339d 100644 --- a/README.md +++ b/README.md @@ -464,6 +464,35 @@ See the [documentation](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.circuit) for further examples. +#### Compiling programs directly in Python + +You may prefer to not have an entirely static `.mpc` file to compile, and may want to compile based on dynamic inputs. For example, you may want to be able to compile with different sizes of input data without making a code change to the `.mpc` file. To handle this, the compiler an also be directly imported, and a function can be compiled with the following interface: + +```python +# hello_world.mpc +from Compiler.library import print_ln +from Compiler.compilerLib import Compiler + +compiler = Compiler() + +@compiler.register_function('helloworld') +def hello_world(): + print_ln('hello world') + +if __name__ == "__main__": + compiler.compile_func() +``` + +You could then run this with: + +```bash +python hello_world.mpc +``` + +This is particularly useful if want to add new command line arguements specifically for your `.mpc` file. See [test_args.mpc](Programs/Source/test_args.mpc) for more details on this use case. + +Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`. + #### Compiling and running programs from external directories Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all MP-SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: @@ -966,7 +995,7 @@ After compiling the mpc file: You can benchmark the ORAM implementation as follows: 1) Edit `Program/Source/gc_oram.mpc` to change size and to choose -Circuit ORAM or linear scan without ORAM. +Circuit ORAM or linear scan without ORAM. 2) Run `./compile.py -D gc_oram`. The `-D` argument instructs the compiler to remove dead code. This is useful for more complex programs such as this one. diff --git a/doc/index.rst b/doc/index.rst index 59caa58de..61acd0457 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -17,8 +17,7 @@ Compilation process The easiest way of using MP-SPDZ is using ``compile.py`` as described below. If you would like to run compilation directly from -Python, see ``Scripts/direct_compilation_example.py``. It contains all -the necessary setup steps. +Python, see :ref:`Direct Compilation in Python`. After putting your code in ``Program/Source/.mpc``, run the compiler from the root directory as follows @@ -140,6 +139,43 @@ computation: to the run time. +Direct Compilation in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You may prefer to not have an entirely static `.mpc` file to compile, +and may want to compile based on dynamic inputs. For example, you may +want to be able to compile with different sizes of input data without +making a code change to the `.mpc` file. To handle this, the compiler +an also be directly imported, and a function can be compiled with the +following interface: + +.. code-block:: python + # hello_world.mpc + from Compiler.library import print_ln + from Compiler.compilerLib import Compiler + + compiler = Compiler() + + @compiler.register_function('helloworld') + def hello_world(): + print_ln('hello world') + + if __name__ == "__main__": + compiler.compile_func() + + +You could then run this with the same args as used with `compile.py`: + +.. code-block:: bash + python hello_world.mpc + +This is particularly useful if want to add new command line arguements +specifically for your `.mpc` file. See [test_args.mpc](Programs/Source/test_args.mpc) +for more details on this use case. + +Note that when using this approach, all objects provided in the high level +interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file +is interpreted directly by Python (instead of being read by `compile.py`.) + Compilation vs run time ~~~~~~~~~~~~~~~~~~~~~~~ From f83476ab2a90e2f44f4661d6f845cc4b560306d2 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 13 Aug 2022 12:23:12 -0700 Subject: [PATCH 122/221] fix small typo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 79bf5339d..8e03710cc 100644 --- a/README.md +++ b/README.md @@ -483,7 +483,7 @@ if __name__ == "__main__": compiler.compile_func() ``` -You could then run this with: +You could then run this with the same args as used with `compile.py`: ```bash python hello_world.mpc @@ -491,7 +491,7 @@ python hello_world.mpc This is particularly useful if want to add new command line arguements specifically for your `.mpc` file. See [test_args.mpc](Programs/Source/test_args.mpc) for more details on this use case. -Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`. +Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`.) #### Compiling and running programs from external directories From 39b6d7e22dedc5798823d306741d33ff6635bc25 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 13 Aug 2022 17:56:24 -0700 Subject: [PATCH 123/221] allow custom_args to manaully override args --- Compiler/compilerLib.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 9e36e9f9e..ae8ca9b2f 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -12,11 +12,12 @@ class Compiler: - def __init__(self, usage=None): + def __init__(self, custom_args=None, usage=None): if usage: self.usage = usage else: self.usage = "usage: %prog [options] filename [args]" + self.custom_args = custom_args self.build_option_parser() self.VARS = {} @@ -205,7 +206,7 @@ def build_option_parser(self): self.parser = parser def parse_args(self): - self.options, self.args = self.parser.parse_args() + self.options, self.args = self.parser.parse_args(self.custom_args) if self.options.optimize_hard: print("Note that -O/--optimize-hard currently has no effect") @@ -358,7 +359,7 @@ def inner(func): return inner def compile_func(self): - if not hasattr(self, "compile_name") and hasattr(self, "compile_func"): + if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")): raise CompilerError( "No function to compile. " "Did you decorate a function with @register_fuction(name)?" From 7630fbc22bfea75004f80827fbc760e92ff9cc37 Mon Sep 17 00:00:00 2001 From: hernan232 Date: Tue, 16 Aug 2022 09:12:17 -0500 Subject: [PATCH 124/221] Correct documentation in BMR table. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 70db2a0c6..5f257c4d5 100644 --- a/README.md +++ b/README.md @@ -842,7 +842,7 @@ lists the available schemes. | Program | Protocol | Dishonest Maj. | Malicious | \# parties | Script | | --- | --- | --- | --- | --- | --- | | `real-bmr-party.x` | MASCOT | Y | Y | 2 or more | `real-bmr.sh` | -| `semi-bmr-party.x` | Semi | Y | Y | 2 or more | `semi-bmr.sh` | +| `semi-bmr-party.x` | Semi | Y | N | 2 or more | `semi-bmr.sh` | | `shamir-bmr-party.x` | Shamir | N | N | 3 or more | `shamir-bmr.sh` | | `mal-shamir-bmr-party.x` | Shamir | N | Y | 3 or more | `mal-shamir-bmr.sh` | | `rep-bmr-party.x` | Replicated | N | N | 3 | `rep-bmr.sh` | From e08a6adb63ea057338f5613645d9d498cb43f2a9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 17 Aug 2022 13:22:04 +1000 Subject: [PATCH 125/221] Fix shuffling in emulation. --- Processor/Instruction.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 7763c8377..fbab7aa9e 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -509,9 +509,12 @@ bool Instruction::get_offline_data_usage(DataPositions& usage) case USE_INP: if (r[0] >= N_DATA_FIELD_TYPE) throw invalid_program(); - if ((unsigned)r[1] >= usage.inputs.size()) - throw Processor_Error("Player number too high"); - usage.inputs[r[1]][r[0]] = n; + if (usage.inputs.size() != 1) + { + if ((unsigned) r[1] >= usage.inputs.size()) + throw Processor_Error("Player number too high"); + usage.inputs[r[1]][r[0]] = n; + } return int(n) >= 0; case USE_EDABIT: usage.edabits[{r[0], r[1]}] = n; From 6a424539c93f5489a6d09360f0092224552d94d8 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 25 Aug 2022 13:20:46 +1000 Subject: [PATCH 126/221] SoftSpokenOT. --- .gitignore | 4 + .gitmodules | 12 +- BMR/Party.cpp | 2 + BMR/RealGarbleWire.h | 2 - BMR/RealGarbleWire.hpp | 2 +- BMR/RealProgramParty.hpp | 4 +- BMR/Register.h | 22 +--- BMR/Register.hpp | 12 +- BMR/Register_inline.h | 6 +- CHANGELOG.md | 17 ++- CONFIG | 18 +-- Compiler/GC/instructions.py | 16 ++- Compiler/GC/types.py | 83 ++++++++++--- Compiler/comparison.py | 21 ++-- Compiler/compilerLib.py | 11 +- Compiler/floatingpoint.py | 32 ++--- Compiler/instructions.py | 69 ++++++----- Compiler/instructions_base.py | 10 +- Compiler/library.py | 29 +++-- Compiler/ml.py | 13 ++- Compiler/mpc_math.py | 4 +- Compiler/program.py | 18 ++- Compiler/types.py | 131 ++++++++++++++------- Compiler/util.py | 3 + ECDSA/P256Element.h | 2 +- ECDSA/fake-spdz-ecdsa-party.cpp | 8 +- ECDSA/ot-ecdsa-party.hpp | 9 +- FHE/Ciphertext.cpp | 2 +- FHE/NTL-Subs.cpp | 6 +- FHE/PPData.cpp | 3 +- FHEOffline/PairwiseMachine.cpp | 17 ++- FHEOffline/PairwiseMachine.h | 25 ++-- FHEOffline/SimpleGenerator.cpp | 2 +- FHEOffline/SimpleGenerator.h | 6 +- FHEOffline/SimpleMachine.cpp | 14 +++ FHEOffline/SimpleMachine.h | 21 +++- GC/AtlasShare.h | 5 - GC/CcdShare.h | 5 - GC/FakeSecret.cpp | 3 +- GC/MaliciousCcdShare.h | 5 - GC/Processor.h | 2 +- GC/Secret.hpp | 2 +- GC/Secret_inline.h | 8 +- GC/ShareParty.hpp | 22 ++-- GC/ShareSecret.hpp | 9 +- GC/ShareThread.h | 1 + GC/ShareThread.hpp | 6 + GC/ThreadMaster.hpp | 10 +- GC/TinierShare.h | 22 +++- GC/TinyMC.h | 10 ++ GC/TinyShare.h | 5 - HOSTS.example | 5 - Machines/OTMachine.cpp | 2 +- Machines/Tinier.cpp | 12 +- Machines/TripleMachine.cpp | 20 +++- Machines/mama-party.cpp | 32 +++-- Machines/spdz2k-party.cpp | 8 +- Machines/tinier-party.cpp | 4 +- Makefile | 94 ++++++++++----- Math/Square.h | 2 + Math/Square.hpp | 9 ++ Math/Zp_Data.cpp | 4 + Math/Zp_Data.h | 2 + Math/bigint.h | 2 +- Networking/CryptoPlayer.cpp | 30 ++++- Networking/CryptoPlayer.h | 5 + Networking/Player.cpp | 119 ++++++++----------- Networking/Player.h | 70 +++++------ Networking/PlayerBuffer.h | 23 ++++ Networking/PlayerCtSocket.h | 169 +++++++++++++++++++++++++++ Networking/Receiver.h | 5 + Networking/Sender.h | 5 + OT/BaseOT.cpp | 125 ++++++++------------ OT/BaseOT.h | 11 +- OT/BitMatrix.h | 3 + OT/BitMatrix.hpp | 7 +- OT/MamaRectangle.h | 5 + OT/NPartyTripleGenerator.h | 5 +- OT/NPartyTripleGenerator.hpp | 38 +++++- OT/OTExtension.cpp | 9 ++ OT/OTExtension.h | 2 +- OT/OTExtensionWithMatrix.cpp | 137 +++++++++++++++++++++- OT/OTExtensionWithMatrix.h | 28 ++++- OT/OTMultiplier.h | 16 +-- OT/OTMultiplier.hpp | 75 ++++++++++-- OT/OTTripleSetup.h | 25 +++- OT/Rectangle.h | 2 + OT/Rectangle.hpp | 7 ++ OT/TripleMachine.h | 7 +- Processor/BaseMachine.cpp | 17 ++- Processor/BaseMachine.h | 11 +- Processor/Instruction.hpp | 15 +-- Processor/Machine.hpp | 23 ++-- Processor/NoFilePrep.h | 2 +- Processor/OfflineMachine.hpp | 16 ++- Processor/Online-Thread.hpp | 3 + Processor/OnlineMachine.h | 5 + Processor/OnlineMachine.hpp | 10 -- Processor/OnlineOptions.cpp | 9 +- Processor/Processor.h | 4 +- Processor/Processor.hpp | 27 +++-- Programs/Source/mnist_full_C.mpc | 1 + Programs/Source/test_args.mpc | 2 +- Programs/Source/test_gc.mpc | 8 +- Protocols/ChaiGearPrep.h | 2 +- Protocols/ChaiGearPrep.hpp | 4 +- Protocols/ChaiGearShare.h | 1 + Protocols/CowGearShare.h | 1 + Protocols/FakeInput.h | 2 +- Protocols/LowGearKeyGen.hpp | 4 + Protocols/MAC_Check.h | 11 ++ Protocols/MAC_Check.hpp | 34 +++++- Protocols/MAC_Check_Base.h | 3 + Protocols/MaliciousRepPrep.hpp | 1 + Protocols/MamaPrep.h | 3 +- Protocols/MamaPrep.hpp | 4 +- Protocols/MamaShare.h | 8 +- Protocols/NoShare.h | 17 ++- Protocols/ProtocolSetup.h | 12 ++ Protocols/ReplicatedPrep.hpp | 8 +- Protocols/SecureShuffle.hpp | 10 +- Protocols/SemiPrep.h | 2 + Protocols/SemiPrep.hpp | 15 +++ Protocols/Share.h | 6 +- Protocols/ShareInterface.h | 2 +- Protocols/Spdz2kPrep.h | 2 +- Protocols/Spdz2kShare.h | 4 +- Protocols/fake-stuff.h | 2 + Protocols/fake-stuff.hpp | 19 ++- README.md | 168 ++++++++++++++------------ Scripts/build.sh | 12 +- Scripts/test_tutorial.sh | 2 +- Tools/Coordinator.cpp | 78 +++++++++++++ Tools/Coordinator.h | 43 +++++++ Tools/NetworkOptions.h | 2 +- Tools/PointerVector.h | 2 +- Tools/Subroutines.cpp | 4 +- Tools/Subroutines.h | 9 +- Tools/Waksman.h | 9 +- Tools/benchmarking.cpp | 6 +- Tools/benchmarking.h | 2 +- Tools/int.h | 5 + Tools/intrinsics.h | 1 + Tools/octetStream.h | 29 ++++- Tools/random.cpp | 35 +----- Tools/random.h | 66 +++++++++-- Tools/time-func.cpp | 3 +- Utils/Fake-Offline.cpp | 35 +++--- Utils/binary-example.cpp | 2 +- Utils/pairwise-offline.cpp | 2 +- Yao/YaoEvalWire.cpp | 2 +- Yao/YaoEvalWire.h | 2 - Yao/YaoGarbleWire.cpp | 2 +- Yao/YaoGarbleWire.h | 2 - azure-pipelines.yml | 11 +- SimpleOT => deps/SimpleOT | 0 deps/SimplestOT_C | 1 + deps/libOTe | 1 + mpir => deps/mpir | 0 simde => deps/simde | 0 doc/Doxyfile | 2 +- doc/add-protocol.rst | 17 +++ doc/compilation.rst | 182 +++++++++++++++++++++++++++++ doc/conf.py | 10 +- doc/gen-readme.sh | 4 + doc/index.rst | 194 +------------------------------ doc/instructions.rst | 5 +- doc/low-level.rst | 18 ++- doc/machine-learning.rst | 4 +- doc/requirements.txt | 3 +- doc/troubleshooting.rst | 2 +- 171 files changed, 2179 insertions(+), 1023 deletions(-) delete mode 100644 HOSTS.example create mode 100644 Networking/PlayerBuffer.h create mode 100644 Networking/PlayerCtSocket.h create mode 100644 Tools/Coordinator.cpp create mode 100644 Tools/Coordinator.h rename SimpleOT => deps/SimpleOT (100%) create mode 160000 deps/SimplestOT_C create mode 160000 deps/libOTe rename mpir => deps/mpir (100%) rename simde => deps/simde (100%) create mode 100644 doc/compilation.rst create mode 100755 doc/gen-readme.sh diff --git a/.gitignore b/.gitignore index 9a4dd72e2..0d770b1ec 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,7 @@ _build/ # environment .env + +# temp doc files +doc/readme.md +doc/xml diff --git a/.gitmodules b/.gitmodules index 32dca28be..7dea81d36 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,18 @@ [submodule "SimpleOT"] - path = SimpleOT + path = deps/SimpleOT url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] - path = mpir + path = deps/mpir url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion [submodule "simde"] - path = simde + path = deps/simde url = https://github.com/simd-everywhere/simde +[submodule "deps/libOTe"] + path = deps/libOTe + url = https://github.com/mkskeller/softspoken-implementation +[submodule "deps/SimplestOT_C"] + path = deps/SimplestOT_C + url = https://github.com/mkskeller/SimplestOT_C diff --git a/BMR/Party.cpp b/BMR/Party.cpp index beddd64cf..0fe11a0f1 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) : } cout << "Compiler: " << prev << endl; P = new PlainPlayer(N, 0); + Share::MAC_Check::setup(*P); if (argc > 4) threshold = atoi(argv[4]); cout << "Threshold for multi-threaded evaluation: " << threshold << endl; @@ -280,6 +281,7 @@ FakeProgramParty::~FakeProgramParty() cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes() << " GB" << endl; #endif + Share::MAC_Check::teardown(); } void FakeProgramParty::_compute_prfs_outputs(Key* keys) diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h index 9fa2dc521..115d0bcaa 100644 --- a/BMR/RealGarbleWire.h +++ b/BMR/RealGarbleWire.h @@ -48,8 +48,6 @@ class RealGarbleWire : public PRFRegister static void inputbvec(GC::Processor>& processor, ProcessorBase& input_processor, const vector& args); - RealGarbleWire(const Register& reg) : PRFRegister(reg) {} - void garble(PRFOutputs& prf_output, const RealGarbleWire& left, const RealGarbleWire& right); diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 760a20b89..c9e31fc65 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -110,7 +110,7 @@ void RealGarbleWire::inputbvec( { GarbleInputter inputter; processor.inputbvec(inputter, input_processor, args, - inputter.party.P->my_num()); + *inputter.party.P); } template diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 64efc5506..70208ec50 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -97,8 +97,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : if (online_opts.live_prep) { mac_key.randomize(prng); - if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({*P, true}); prep = new typename T::LivePrep(0, usage); } else @@ -107,6 +105,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : prep = new Sub_Data_Files(N, prep_dir, usage); } + T::MAC_Check::setup(*P); MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); @@ -219,6 +218,7 @@ RealProgramParty::~RealProgramParty() delete garble_inputter; delete garble_protocol; cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; + T::MAC_Check::teardown(); } template diff --git a/BMR/Register.h b/BMR/Register.h index f348f7b7e..6a15a720c 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -152,7 +152,7 @@ class Register { * for pipelining matters. */ - Register(int n_parties); + Register(); void init(int n_parties); void init(int rfd, int n_parties); @@ -278,10 +278,6 @@ class ProgramRegister : public Phase, public Register static int threshold(int) { throw not_implemented(); } - static Register new_reg(); - static Register tmp_reg() { return new_reg(); } - static Register and_reg() { return new_reg(); } - template static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } @@ -306,8 +302,6 @@ class ProgramRegister : public Phase, public Register void other_input(Input&, int) {} char get_output() { return 0; } - - ProgramRegister(const Register& reg) : Register(reg) {} }; class PRFRegister : public ProgramRegister @@ -319,8 +313,6 @@ class PRFRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - PRFRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const PRFRegister& left, const PRFRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); @@ -396,8 +388,6 @@ class EvalRegister : public ProgramRegister static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>& proc); - EvalRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const ProgramRegister& left, const ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); @@ -427,8 +417,6 @@ class GarbleRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - GarbleRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const Register& left, const Register& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char value = -1); @@ -452,8 +440,6 @@ class RandomRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - RandomRegister(const Register& reg) : ProgramRegister(reg) {} - void randomize(); void op(const Register& left, const Register& right, Function func); @@ -469,12 +455,6 @@ class RandomRegister : public ProgramRegister }; -inline Register::Register(int n_parties) : - garbled_entry(n_parties), external(NO_SIGNAL), - mask(NO_SIGNAL), keys(n_parties) -{ -} - inline void KeyVector::operator=(const KeyVector& other) { resize(other.size()); diff --git a/BMR/Register.hpp b/BMR/Register.hpp index bd214a858..617906945 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor, const vector& args) { NoOpInputter inputter; - int my_num = -1; - try - { - my_num = ProgramParty::s().P->my_num(); - } - catch (exception&) - { - } - processor.inputbvec(inputter, input_processor, args, my_num); + processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P); } template @@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor, { EvalInputter inputter; processor.inputbvec(inputter, input_processor, args, - ProgramParty::s().P->my_num()); + *ProgramParty::s().P); } template diff --git a/BMR/Register_inline.h b/BMR/Register_inline.h index 6a275da64..7694c464d 100644 --- a/BMR/Register_inline.h +++ b/BMR/Register_inline.h @@ -9,10 +9,10 @@ #include "CommonParty.h" #include "Party.h" - -inline Register ProgramRegister::new_reg() +inline Register::Register() : + garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL), + mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties()) { - return Register(CommonParty::s().get_n_parties()); } #endif /* BMR_REGISTER_INLINE_H_ */ diff --git a/CHANGELOG.md b/CHANGELOG.md index ac6435805..e8e015348 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.3 (Aug 25, 2022) + +- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate +- Fix security bug in MAC check when using multithreading +- Fix security bug to prevent selective failure attack by checking earlier +- Fix security bug in Mama: insufficient sacrifice. +- Inverse permutation (@Quitlox) +- Easier direct compilation (@eriktaubeneck) +- Generally allow element-vector operations +- Increase maximum register size to 2^54 +- Client example in Python +- Uniform base OTs across platforms +- Multithreaded base OT computation +- Faster random bit generation in two-player Semi(2k) + ## 0.3.2 (May 27, 2022) - Secure shuffling @@ -7,7 +22,7 @@ The changelog explains changes pulled through from the private development repos - Documented BGV encryption interface - Optimized matrix multiplication in dealer protocol - Fixed security bug in homomorphic encryption parameter generation -- Fixed Security bug in Temi matrix multiplication +- Fixed security bug in Temi matrix multiplication ## 0.3.1 (Apr 19, 2022) diff --git a/CONFIG b/CONFIG index cef15e0b4..fb9db2009 100644 --- a/CONFIG +++ b/CONFIG @@ -31,24 +31,21 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx ARCH = -march=native MACHINE := $(shell uname -m) +ARM := $(shell uname -m | grep x86; echo $$?) OS := $(shell uname -s) ifeq ($(MACHINE), x86_64) -# set this to 0 to avoid using AVX for OT ifeq ($(OS), Linux) -CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?) -ifeq ($(CHECK_AVX), 0) AVX_OT = 1 else AVX_OT = 0 endif else -AVX_OT = 1 -endif -else ARCH = AVX_OT = 0 endif +USE_KOS = 0 + # allow to set compiler in CONFIG.mine CXX = g++ @@ -87,7 +84,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) @@ -98,3 +95,10 @@ ifeq ($(USE_NTL),1) CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy endif endif + +ifeq ($(USE_KOS),1) +CFLAGS += -DUSE_KOS +else +CFLAGS += -std=c++17 +LDLIBS += -llibOTe -lcryptoTools +endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index e53b71879..2b5ec46ad 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -342,7 +342,8 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): code = opcodes['STMCB'] arg_format = ['cb','long'] -class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit memory cell with run-time address to secret bit register. @@ -351,8 +352,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] + direct = staticmethod(ldmsb) -class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit register to secret bit memory cell with run-time address. @@ -361,8 +364,10 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMSBI'] arg_format = ['sb','ci'] + direct = staticmethod(stmsb) -class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit memory cell with run-time address to clear bit register. @@ -371,8 +376,10 @@ class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMCBI'] arg_format = ['cbw','ci'] + direct = staticmethod(ldmcb) -class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit register to clear bit memory cell with run-time address. @@ -381,6 +388,7 @@ class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMCBI'] arg_format = ['cb','ci'] + direct = staticmethod(stmcb) class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5530432bb..396769e03 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -198,6 +198,8 @@ def __and__(self, other): return 0 elif self.is_long_one(other): return self + elif isinstance(other, _vec): + return other & other.from_vec([self]) else: return self._and(other) @read_mem_value @@ -241,6 +243,13 @@ def zero_if_not(self, condition): return self * condition else: return self * cbit.conv(condition) + def expand(self, length): + if self.n in (length, None): + return self + elif self.n == 1: + return self.get_type(length).bit_compose([self] * length) + else: + raise CompilerError('cannot expand from %s to %s' % (self.n, length)) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -295,8 +304,15 @@ def clear_op(self, other, c_inst, ci_inst, op): return op(self, cbits(other)) __add__ = lambda self, other: \ self.clear_op(other, inst.addcb, inst.addcbi, operator.add) - __sub__ = lambda self, other: \ - self.clear_op(-other, inst.addcb, inst.addcbi, operator.add) + def __sub__(self, other): + try: + return self + -other + except: + return type(self)(regint(self) - regint(other)) + def __rsub__(self, other): + return type(self)(other - regint(self)) + def __neg__(self): + return type(self)(-regint(self)) def _xor(self, other): if isinstance(other, (sbits, sbitvec)): return NotImplemented @@ -589,7 +605,15 @@ def trans(cls, rows): rows = list(rows) if len(rows) == 1 and rows[0].n <= rows[0].unit: return rows[0].bit_decompose() - n_columns = rows[0].n + for row in rows: + try: + n_columns = row.n + break + except: + pass + for i in range(len(rows)): + if util.is_zero(rows[i]): + rows[i] = cls.get_type(n_columns)(0) for row in rows: assert(row.n == n_columns) if n_columns == 1 and len(rows) <= cls.unit: @@ -613,7 +637,7 @@ def bit_adder(*args, **kwargs): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(_vec): +class sbitvec(_vec, _bit): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. @@ -656,6 +680,7 @@ class sbitvec(_vec): [1, 0, 1] """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) + is_clear = False @classmethod def get_type(cls, n): """ Create type for fixed-length vector of registers of secret bits. @@ -691,10 +716,11 @@ def from_vec(cls, vector): res.v = _complement_two_extend(list(vector), n)[:n] return res def __init__(self, other=None, size=None): - assert size in (None, 1) if other is not None: if util.is_constant(other): - self.v = [sbit((other >> i) & 1) for i in range(n)] + t = sbits.get_type(size or 1) + self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1)) + for i in range(n)] elif isinstance(other, _vec): self.v = self.bit_extend(other.v, n) elif isinstance(other, (list, tuple)): @@ -702,6 +728,7 @@ def __init__(self, other=None, size=None): else: self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n + assert size is None or size == self.v[0].n @classmethod def load_mem(cls, address, size=None): if size not in (None, 1): @@ -733,8 +760,9 @@ def store_in_mem(self, address): def reveal(self): return util.untuplify([x.reveal() for x in self.elements()]) @classmethod - def two_power(cls, nn): - return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) + def two_power(cls, nn, size=1): + return cls.from_vec( + [0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1)) def coerce(self, other): if util.is_constant(other): return self.from_vec(util.bit_decompose(other, n)) @@ -818,16 +846,14 @@ def coerce(self, other): return other def __xor__(self, other): other = self.coerce(other) - return self.from_vec(x ^ y for x, y in zip(self.v, other)) + return self.from_vec(x ^ y for x, y in zip(*self.expand(other))) def __and__(self, other): - return self.from_vec(x & y for x, y in zip(self.v, other.v)) + return self.from_vec(x & y for x, y in zip(*self.expand(other))) + def __invert__(self): + return self.from_vec(~x for x in self.v) def if_else(self, x, y): assert(len(self.v) == 1) - try: - return self.from_vec(util.if_else(self.v[0], a, b) \ - for a, b in zip(x, y)) - except: - return util.if_else(self.v[0], x, y) + return util.if_else(self.v[0], x, y) def __iter__(self): return iter(self.v) def __len__(self): @@ -890,6 +916,24 @@ def tree_reduce(self, function): elements = red.elements() elements += odd return self.from_vec(sbitvec(elements).v) + @classmethod + def comp_result(cls, x): + return cls.get_type(1).from_vec([x]) + def expand(self, other, expand=True): + m = 1 + for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []): + try: + m = max(m, x.n) + except: + pass + res = [] + for y in self, other: + if isinstance(y, int): + res.append([x * sbits.get_type(m)().long_one() + for x in util.bit_decompose(y, len(self.v))]) + else: + res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) + return res class bit(object): n = 1 @@ -1139,7 +1183,7 @@ def pow2(self, k): :param k: bit length of input """ return _sbitintbase.pow2(self, k) -class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): +class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ Vector of signed integers for parallel binary computation:: @@ -1176,7 +1220,8 @@ def __add__(self, other): return self other = self.coerce(other) assert(len(self.v) == len(other.v)) - v = sbitint.bit_adder(self.v, other.v) + a, b = self.expand(other) + v = sbitint.bit_adder(a, b) return self.from_vec(v) __radd__ = __add__ def __mul__(self, other): @@ -1184,7 +1229,7 @@ def __mul__(self, other): return self.from_vec(other * x for x in self.v) elif isinstance(other, sbitfixvec): return NotImplemented - other_bits = util.bit_decompose(other) + _, other_bits = self.expand(other, False) m = float('inf') for x in itertools.chain(self.v, other_bits): try: @@ -1228,6 +1273,8 @@ class cbitfix(object): store_in_mem = lambda self, *args: self.v.store_in_mem(*args) @classmethod def _new(cls, value): + if isinstance(value, list): + return [cls._new(x) for x in value] res = cls() if cls.k < value.unit: bits = value.bit_decompose(cls.k) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 84bdd22b6..1a139ef6d 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -87,15 +87,14 @@ def LtzRing(a, k): carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] return sint.conv(msb) - return - elif program.options.ring: + else: from . import floatingpoint require_ring_size(k, 'comparison') m = k - 1 shift = int(program.options.ring) - k r_prime, r_bin = MaskingBitsInRing(k) tmp = a - r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) @@ -190,7 +189,7 @@ def TruncLeakyInRing(a, k, m, signed): r = sint.bit_compose(r_bits) if signed: a += (1 << (k - 1)) - shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() + shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False) masked = shifted >> n_shift u = sint() BitLTL(u, masked, r_bits[:n_bits], 0) @@ -231,7 +230,7 @@ def Mod2mRing(a_prime, a, k, m, signed): shift = int(program.options.ring) - m r_prime, r_bin = MaskingBitsInRing(m, True) tmp = a + r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift u = sint() BitLTL(u, c_prime, r_bin[:m], 0) res = (u << m) + c_prime - r_prime @@ -261,7 +260,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) modc(c_prime, c, c2m) if const_rounds: BitLTC1(u, c_prime, r, kappa) @@ -510,7 +509,7 @@ def PreMulC_with_inverses_and_vectors(p, a): movs(w[0], r[0]) movs(a_vec[0], a[0]) vmuls(k, t[0], w, a_vec) - vasm_open(k, m, t[0]) + vasm_open(k, True, m, t[0]) PreMulC_end(p, a, c, m, z) def PreMulC_with_inverses(p, a): @@ -538,7 +537,7 @@ def PreMulC_with_inverses(p, a): w[1][0] = r[0][0] for i in range(k): muls(t[0][i], w[1][i], a[i]) - asm_open(m[i], t[0][i]) + asm_open(True, m[i], t[0][i]) PreMulC_end(p, a, c, m, z) def PreMulC_without_inverses(p, a): @@ -563,7 +562,7 @@ def PreMulC_without_inverses(p, a): #adds(tt[0][i], t[0][i], a[i]) #subs(tt[1][i], tt[0][i], a[i]) #startopen(tt[1][i]) - asm_open(u[i], t[0][i]) + asm_open(True, u[i], t[0][i]) for i in range(k-1): muls(v[i], r[i+1], s[i]) w[0] = r[0] @@ -579,7 +578,7 @@ def PreMulC_without_inverses(p, a): mulm(z[i], s[i], u_inv[i]) for i in range(k): muls(t[1][i], w[i], a[i]) - asm_open(m[i], t[1][i]) + asm_open(True, m[i], t[1][i]) PreMulC_end(p, a, c, m, z) def PreMulC_end(p, a, c, m, z): @@ -646,7 +645,7 @@ def Mod2(a_0, a, k, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index eb800ba48..4a4706ff6 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -181,7 +181,8 @@ def build_option_parser(self): action="store_true", dest="invperm", help="speedup inverse permutation (only use in two-party, " - "semi-honest environment)") + "semi-honest environment)" + ) parser.add_option( "-C", "--CISC", @@ -244,11 +245,9 @@ def build_vars(self): self.VARS[op.__name__] = op # add open and input separately due to name conflict - self.VARS["open"] = instructions.asm_open self.VARS["vopen"] = instructions.vasm_open self.VARS["gopen"] = instructions.gasm_open self.VARS["vgopen"] = instructions.vgasm_open - self.VARS["input"] = instructions.asm_input self.VARS["ginput"] = instructions.gasm_input self.VARS["comparison"] = comparison @@ -268,7 +267,6 @@ def build_vars(self): "sgf2nuint", "sgf2nuint32", "sgf2nfloat", - "sfloat", "cfloat", "squant", ]: @@ -276,6 +274,9 @@ def build_vars(self): def prep_compile(self, name=None): self.parse_args() + if len(self.args) < 1 and name is None: + self.parser.print_help() + exit(1) self.build_program(name=name) self.build_vars() @@ -372,7 +373,7 @@ def compile_func(self): ) self.prep_compile(self.compile_name) print( - f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" + "Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__) ) self.compile_function() self.finalize_compile() diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index d3d3f8c50..94a47f1bf 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,7 +28,7 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_edabit: + if program.Program.prog.use_edabit(): r_prime, r = types.sint.get_edabit(k) elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) @@ -36,7 +36,7 @@ def maskRing(a, k): else: r = [types.sint.get_random_bit() for i in range(k)] r_prime = types.sint.bit_compose(r) - c = ((a + r_prime) << shift).reveal() >> shift + c = ((a + r_prime) << shift).reveal(False) >> shift return c, r def maskField(a, k, kappa): @@ -47,7 +47,7 @@ def maskField(a, k, kappa): comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) # always signed due to usage in equality testing a += two_power(k) - asm_open(c, a + two_power(k) * r_dprime + r_prime) + asm_open(True, c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc @@ -233,7 +233,7 @@ def Inv(a): ldi(one, 1) inverse(t[0], t[1]) s = t[0]*a - asm_open(c[0], s) + asm_open(True, c[0], s) # avoid division by zero for benchmarking divc(c[1], one, c[0]) #divc(c[1], c[0], one) @@ -281,7 +281,7 @@ def BitDecRingRaw(a, k, m): else: r_bits = [types.sint.get_random_bit() for i in range(m)] r = types.sint.bit_compose(r_bits) - shifted = ((a - r) << n_shift).reveal() + shifted = ((a - r) << n_shift).reveal(False) masked = shifted >> n_shift bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return bits @@ -299,7 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): r = [types.sint() for i in range(m)] comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) pow2 = two_power(k + kappa) - asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) + asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) instructions_base.reset_global_vector_size() return res @@ -341,10 +341,10 @@ def B2U_from_Pow2(pow2a, l, kappa): if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l assert n_shift > 0 - c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift + c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(t, kappa) - asm_open(c, pow2a + two_power(l) * t + + asm_open(True, c, pow2a + two_power(l) * t + sum(two_power(i) * r[i] for i in range(l))) comparison.program.curr_tape.require_bit_length(l + kappa) c = list(r_bits[0].bit_decompose_clear(c, l)) @@ -386,11 +386,11 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): r_dprime += t1 - t2 if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l - c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift + c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(rk, kappa) r_dprime += two_power(l) * rk - asm_open(c, a + r_dprime + r_prime) + asm_open(True, c, a + r_dprime + r_prime) for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) @@ -416,7 +416,7 @@ def TruncInRing(to_shift, l, pow2m): rev *= pow2m r_bits = [types.sint.get_random_bit() for i in range(l)] r = types.sint.bit_compose(r_bits) - shifted = (rev - (r << n_shift)).reveal() + shifted = (rev - (r << n_shift)).reveal(False) masked = shifted >> n_shift bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l)) return types.sint.bit_compose(reversed(bits)) @@ -457,7 +457,7 @@ def Int2FL(a, gamma, l, kappa=None): v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) else: v = 2**(l-gamma+1) * t - p = (p + gamma - 1 - l) * (1 -z) + p = (p + gamma - 1 - l) * z.bit_not() return v, p, z, s def FLRound(x, mode): @@ -530,7 +530,7 @@ def TruncPrRing(a, k, m, signed=True): msb = r_bits[-1] n_shift = n_ring - (k + 1) tmp = a + r - masked = (tmp << n_shift).reveal() + masked = (tmp << n_shift).reveal(False) shifted = (masked << 1 >> (n_shift + m + 1)) overflow = msb.bit_xor(masked >> (n_ring - 1)) res = shifted - upper + \ @@ -551,7 +551,7 @@ def TruncPrField(a, k, m, kappa=None): k, m, kappa, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime - c = (b + r).reveal() + c = (b + r).reveal(False) c_prime = c % two_to_m a_prime = c_prime - r_prime d = (a - a_prime) / two_to_m @@ -667,14 +667,14 @@ def get_bits_loop(): def _(): for i in range(bit_length): tbits[j][i].link(sint.get_random_bit()) - c = regint(BITLT(tbits[j], pbits, bit_length).reveal()) + c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False)) done[j].link(c) return (sum(done) != a.size) for j in range(a.size): for i in range(bit_length): movs(bbits[i][j], tbits[j][i]) b = sint.bit_compose(bbits) - c = (a-b).reveal() + c = (a-b).reveal(False) cmodp = c t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 058b6ff4f..91809ba44 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -387,6 +387,14 @@ class use(base.Instruction): code = base.opcodes['USE'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), + find_in_dict(data_types, args[1].i)): + args[2].i} + class use_inp(base.Instruction): """ Input usage. Necessary to avoid reusage while using preprocessing from files. @@ -398,6 +406,13 @@ class use_inp(base.Instruction): code = base.opcodes['USE_INP'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), 'input', args[1].i): + args[2].i} + class use_edabit(base.Instruction): """ edaBit usage. Necessary to avoid reusage while using preprocessing from files. Also used to multithreading for expensive @@ -410,6 +425,10 @@ class use_edabit(base.Instruction): code = base.opcodes['USE_EDABIT'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i} + class use_matmul(base.Instruction): """ Matrix multiplication usage. Used for multithreading of preprocessing. @@ -471,6 +490,11 @@ class use_prep(base.Instruction): code = base.opcodes['USE_PREP'] arg_format = ['str','int'] + @classmethod + def get_usage(cls, args): + return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp', + args[0].str): args[1].i} + class nplayers(base.Instruction): """ Store number of players in clear integer register. @@ -783,30 +807,6 @@ def has_var_args(self): return True -### -### Special GF(2) arithmetic instructions -### - -@base.vectorize -class gmulbitc(base.MulBase): - r""" Clear GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITC'] - arg_format = ['cgw','cg','cg'] - - def is_gf2n(self): - return True - -@base.vectorize -class gmulbitm(base.MulBase): - r""" Secret GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITM'] - arg_format = ['sgw','sg','cg'] - - def is_gf2n(self): - return True - ### ### Arithmetic with immediate values ### @@ -1707,6 +1707,7 @@ class writesockets(base.IOInstruction): from registers into a socket for a specified client id. If the protocol uses MACs, the client should be different for every party. + :param: number of arguments to follow :param: client id (regint) :param: message type (must be 0) :param: vector size (int) @@ -2162,14 +2163,19 @@ class gconvgf2n(base.Instruction): class asm_open(base.VarArgsInstruction): """ Reveal secret registers (vectors) to clear registers (vectors). - :param: number of argument to follow (multiple of two) + :param: number of argument to follow (odd number) + :param: check after opening (0/1) :param: destination (cint) :param: source (sint) :param: (repeat the last two)... """ __slots__ = [] code = base.opcodes['OPEN'] - arg_format = tools.cycle(['cw','s']) + arg_format = tools.chain(['int'], tools.cycle(['cw','s'])) + + def merge(self, other): + self.args[0] |= other.args[0] + self.args += other.args[1:] @base.gf2n @base.vectorize @@ -2415,12 +2421,17 @@ class shuffle_base(base.DataInstruction): def logn(n): return int(math.ceil(math.log(n, 2))) + @classmethod + def n_swaps(cls, n): + logn = cls.logn(n) + return logn * 2 ** logn - 2 ** logn + 1 + def add_gen_usage(self, req_node, n): # hack for unknown usage req_node.increment(('bit', 'inverse'), float('inf')) # minimal usage with two relevant parties logn = self.logn(n) - n_switches = logn * 2 ** logn + n_switches = self.n_swaps(n) for i in range(self.n_relevant_parties): req_node.increment((self.field_type, 'input', i), n_switches) # multiplications for bit check @@ -2430,7 +2441,7 @@ def add_gen_usage(self, req_node, n): def add_apply_usage(self, req_node, n, record_size): req_node.increment(('bit', 'inverse'), float('inf')) logn = self.logn(n) - n_switches = logn * 2 ** logn * self.n_relevant_parties + n_switches = self.n_swaps(n) * self.n_relevant_parties if n != 2 ** logn: record_size += 1 req_node.increment((self.field_type, 'triple'), @@ -2548,7 +2559,7 @@ def expand(self): c = [program.curr_block.new_reg('c') for i in range(2)] square(s[0], s[1]) subs(s[2], self.args[1], s[0]) - asm_open(c[0], s[2]) + asm_open(False, c[0], s[2]) mulc(c[1], c[0], c[0]) mulm(s[3], self.args[1], c[0]) adds(s[4], s[3], s[3]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index f7aa48f9b..fb60d908b 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -542,7 +542,7 @@ def add_usage(self, *args): def get_bytes(self): assert len(self.kwargs) < 2 - res = int_to_bytes(opcodes['CISC']) + res = LongArgFormat.encode(opcodes['CISC']) res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1) name = self.function.__name__ String.check(name) @@ -720,7 +720,7 @@ def __str__(self): class LongArgFormat(IntArgFormat): @classmethod def encode(cls, arg): - return struct.pack('>Q', arg) + return list(struct.pack('>Q', arg)) def __init__(self, f): self.i = struct.unpack('>Q', f.read(8))[0] @@ -741,6 +741,8 @@ def check(cls, arg): class PlayerNoAF(IntArgFormat): @classmethod def check(cls, arg): + if not util.is_constant(arg): + raise CompilerError('Player number must be known at compile time') super(PlayerNoAF, cls).check(arg) if arg > 256: raise ArgumentError(arg, 'Player number > 256') @@ -823,7 +825,7 @@ def get_code(self, prefix=0): return (prefix << self.code_length) + self.code def get_encoding(self): - enc = int_to_bytes(self.get_code()) + enc = LongArgFormat.encode(self.get_code()) # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) @@ -958,7 +960,7 @@ def __init__(self, f): except AttributeError: pass read = lambda: struct.unpack('>I', f.read(4))[0] - full_code = read() + full_code = struct.unpack('>Q', f.read(8))[0] code = full_code % (1 << Instruction.code_length) self.size = full_code >> Instruction.code_length self.type = cls.reverse_opcodes[code] diff --git a/Compiler/library.py b/Compiler/library.py index 799f85d29..1da50e9c9 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -243,6 +243,10 @@ def store_in_mem(value, address): try: value.store_in_mem(address) except AttributeError: + if isinstance(value, (list, tuple)): + for i, x in enumerate(value): + store_in_mem(x, address + i) + return # legacy if value.is_clear: if isinstance(address, cint): @@ -261,11 +265,13 @@ def reveal(secret): try: return secret.reveal() except AttributeError: + if secret.is_clear: + return secret if secret.is_gf2n: res = cgf2n() else: res = cint() - instructions.asm_open(res, secret) + instructions.asm_open(True, res, secret) return res @vectorize @@ -883,10 +889,10 @@ def loop_fn(i): def for_range(start, stop=None, step=None): """ Decorator to execute loop bodies consecutively. Arguments work as - in Python :py:func:`range`, but they can by any public + in Python :py:func:`range`, but they can be any public integer. Information has to be passed out via container types such - as :py:class:`~Compiler.types.Array` or declaring registers as - :py:obj:`global`. Note that changing Python data structures such + as :py:class:`~Compiler.types.Array` or using :py:func:`update`. + Note that changing Python data structures such as lists within the loop is not possible, but the compiler cannot warn about this. @@ -901,13 +907,11 @@ def for_range(start, stop=None, step=None): @for_range(n) def _(i): a[i] = i - global x - x += 1 + x.update(x + 1) Note that you cannot overwrite data structures such as - :py:class:`~Compiler.types.Array` in a loop even when using - :py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign` - instead. + :py:class:`~Compiler.types.Array` in a loop. Use + :py:func:`~Compiler.types.Array.assign` instead. """ def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -1518,6 +1522,11 @@ class State: pass state = State() if callable(condition): condition = condition() + try: + if not condition.is_clear: + raise CompilerError('cannot branch on secret values') + except AttributeError: + pass state.condition = regint.conv(condition) state.start_block = instructions.program.curr_block state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ @@ -1889,7 +1898,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): theta = int(ceil(log(k/3.5) / log(2))) base.set_global_vector_size(b.size) - alpha = b.get_type(2 * k).two_power(2*f) + alpha = b.get_type(2 * k).two_power(2*f, size=b.size) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() diff --git a/Compiler/ml.py b/Compiler/ml.py index 02f0f04ed..173c2eac0 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -148,7 +148,7 @@ def argmax(x): """ Compute index of maximum element. :param x: iterable - :returns: sint + :returns: sint or 0 if :py:obj:`x` has length 1 """ def op(a, b): comp = (a[1] > b[1]) @@ -164,7 +164,7 @@ def softmax(x): return softmax_from_exp(exp_for_softmax(x)[0]) def exp_for_softmax(x): - m = util.max(x) + m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2) mv = m.expand_to_vector(len(x)) try: x = x.get_vector() @@ -2384,6 +2384,11 @@ def output_weights(self): for layer in self.layers: layer.output_weights() + def summary(self): + sizes = [var.total_size() for var in self.thetas] + print(sizes) + print('Trainable params:', sum(sizes)) + class Adam(Optimizer): """ Adam/AMSgrad optimizer. @@ -2653,9 +2658,7 @@ def trainable_variables(self): return list(self.opt.thetas) def summary(self): - sizes = [var.total_size() for var in self.trainable_variables] - print(sizes) - print('Trainable params:', sum(sizes)) + self.opt.summary() def build(self, input_shape, batch_size=128): data_input_shape = input_shape diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 47253dc43..abdaf1233 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -295,7 +295,6 @@ class my_fix(type(a)): intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k if types.program.use_split(): - assert not zero_output from Compiler.GC.types import sbitvec if types.program.use_split() == 3: x = a.v.split_to_two_summands(a.k) @@ -327,6 +326,7 @@ class my_fix(type(a)): s = sint.conv(bits[-1]) lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) higher_bits = bits[a.f:n_bits] + bits_to_check = bits[n_bits:-1] else: if types.program.use_edabit(): l = sint.get_edabit(a.f, True) @@ -338,7 +338,7 @@ class my_fix(type(a)): r_bits = [sint.get_random_bit() for i in range(a.k)] r = sint.bit_compose(r_bits) lower_r = sint.bit_compose(r_bits[:a.f]) - shifted = ((a.v - r) << n_shift).reveal() + shifted = ((a.v - r) << n_shift).reveal(False) masked_bits = (shifted >> n_shift).bit_decompose(a.k) lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], r_bits[a.f-1::-1]) diff --git a/Compiler/program.py b/Compiler/program.py index d7b57db90..f92ab4971 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -545,7 +545,7 @@ def use_invperm(self, change=None): """ if change is None: if not self._invperm: - self.relevant_opts.add('invperm') + self.relevant_opts.add("invperm") return self._invperm else: self._invperm = change @@ -1276,7 +1276,7 @@ class Register(_no_truth): "can_eliminate", "duplicates", ] - maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 + maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): """Creates a new register. @@ -1382,6 +1382,20 @@ def link(self, other): for dup in self.duplicates: dup.duplicates = self.duplicates + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + if self.program != other.program: + raise CompilerError( + 'cannot update register with one from another thread') + self.link(other) + @property def is_gf2n(self): return ( diff --git a/Compiler/types.py b/Compiler/types.py index d63295c8f..6a150beee 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,8 +127,14 @@ def vectorized_operation(self, *args, **kwargs): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise VectorMismatch('Different vector sizes of operands: %d/%d' - % (self.size, args[0].size)) + if min(args[0].size, self.size) == 1: + size = max(args[0].size, self.size) + self = self.expand_to_vector(size) + args = list(args) + args[0] = args[0].expand_to_vector(size) + else: + raise VectorMismatch('Different vector sizes of operands: %d/%d' + % (self.size, args[0].size)) set_global_vector_size(self.size) try: res = operation(self, *args, **kwargs) @@ -249,8 +255,11 @@ def __mul__(self, other): try: return self.mul(other) except VectorMismatch: - # try reverse multiplication - return NotImplemented + if type(self) != type(other) and 1 in (self.size, other.size): + # try reverse multiplication + return NotImplemented + else: + raise __radd__ = __add__ __rmul__ = __mul__ @@ -1658,6 +1667,8 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') intoutput(player, self) class localint(Tape._no_truth): @@ -2081,12 +2092,15 @@ def __rsub__(self, other): return self.secret_op(other, subs, submr, subsfi, True) __rsub__.__doc__ = __sub__.__doc__ - @vectorize def __truediv__(self, other): """ Secret field division. :param other: any compatible type """ - return self * (self.clear_type(1) / other) + try: + one = self.clear_type(1, size=other.size) + except AttributeError: + one = self.clear_type(1) + return self * (one / other) @vectorize def __rtruediv__(self, other): @@ -2113,12 +2127,12 @@ def secure_shuffle(self, unit_size=1): @set_instruction_type @vectorize - def reveal(self): + def reveal(self, check=True): """ Reveal secret value publicly. :rtype: relevant clear type """ res = self.clear_type() - asm_open(res, self) + asm_open(check, res, self) return res @set_instruction_type @@ -2166,9 +2180,7 @@ class sint(_secret, _int): signed integer in a restricted range, see below. The same holds for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and exponentation (``**``). Modulo only works if the right-hand - operator is a compile-time power of two, and exponentiation only - works if the base is two or if the exponent is a compile-time - integer. + operator is a compile-time power of two. Most non-linear operations require compile-time parameters for bit length and statistical security. They default to the global @@ -2672,7 +2684,7 @@ def trunc_zeros(self, n_zeros, bit_length=None, signed=True): return comparison.TruncZeros(self, bit_length, n_zeros, signed) @staticmethod - def two_power(n): + def two_power(n, size=None): return floatingpoint.two_power(n) def split_to_n_summands(self, length, n): @@ -2690,7 +2702,6 @@ def split_to_two_summands(self, length, get_carry=False): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -2698,13 +2709,14 @@ def reveal_to(self, player): :returns: :py:class:`personal` """ if not util.is_constant(player): - secret_mask = sint() - player_mask = cint() - inputmaskreg(secret_mask, player_mask, regint.conv(player)) + secret_mask = sint(size=self.size) + player_mask = cint(size=self.size) + inputmaskreg(secret_mask, player_mask, + regint.conv(player).expand_to_vector(self.size)) return personal(player, - (self + secret_mask).reveal() - player_mask) + (self + secret_mask).reveal(False) - player_mask) else: - res = personal(player, self.clear_type()) + res = personal(player, self.clear_type(size=self.size)) privateoutput(self.size, player, res._v, self) return res @@ -2856,6 +2868,10 @@ def __rsub__(self, other): else: return super(sintbit, self).__rsub__(other) + __rand__ = __and__ + __rxor__ = __xor__ + __ror__ = __or__ + class sgf2n(_secret, _gf2n): """ Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A @@ -2873,6 +2889,7 @@ class sgf2n(_secret, _gf2n): instruction_type = 'gf2n' clear_type = cgf2n reg_type = 'sg' + long_one = staticmethod(lambda: 1) @classmethod def get_type(cls, length): @@ -3022,6 +3039,7 @@ class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False + comp_result = staticmethod(lambda x: x) @staticmethod def half_adder(a, b): @@ -3241,12 +3259,16 @@ def wallace_reduction(cls, a, b, c, get_carry=True): del carries[-1] return sums, carries + def expand(self, other): + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + return a, b + def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) from util import bit_not, bit_and, bit_xor + a, b = self.expand(other) n = 1 for x in (a + b): try: @@ -3293,8 +3315,7 @@ def prep_comparison(a, b): a[-1], b[-1] = b[-1], a[-1] def comparison(self, other, const_rounds=False, index=None): - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) + a, b = self.expand(other) self.prep_comparison(a, b) if const_rounds: return self.get_highest_different_bits(a, b, index) @@ -3304,30 +3325,33 @@ def comparison(self, other, const_rounds=False, index=None): def __lt__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 0) + res = util.if_else(not_equal, x, 0) else: - return self.comparison(other, True, 1) + res = self.comparison(other, True, 1) + return self.comp_result(res) def __le__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 1) + res = util.if_else(not_equal, x, x.long_one()) else: - return 1 - self.comparison(other, True, 0) + res = self.comparison(other, True, 0).bit_not() + return self.comp_result(res) def __ge__(self, other): - return 1 - (self < other) + return (self < other).bit_not() def __gt__(self, other): - return 1 - (self <= other) + return (self <= other).bit_not() def __eq__(self, other, bit_length=None, security=None): diff = self ^ other - diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]] - return floatingpoint.KMul(diff_bits) + diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] + return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), + diff_bits)) def __ne__(self, other): - return 1 - (self == other) + return (self == other).bit_not() equal = __eq__ @@ -3881,7 +3905,6 @@ def print_plain(self): def output_if(self, cond): cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size)) - @vectorize def binary_output(self, player=None): """ Write double-precision floating-point number to ``Player-Data/Binary-Output-P-``. @@ -3890,7 +3913,11 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') + set_global_vector_size(self.size) floatoutput(player, self.v, cint(-self.f), cint(0), cint(0)) + reset_global_vector_size() class _single(_number, _secret_structure): """ Representation as single integer preserving the order """ @@ -4124,6 +4151,7 @@ def get_vector(self): class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k'] + is_clear = False def set_precision(cls, f, k = None): cls.f = f @@ -4349,6 +4377,18 @@ def bit_decompose(self, n_bits=None): """ Bit decomposition. """ return self.v.bit_decompose(n_bits or self.k) + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + assert self.f == other.f + self.v.update(other.v) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4737,6 +4777,8 @@ class sfloat(_number, _secret_structure): returning :py:class:`sint`. The other operand can be any of sint/cfix/regint/cint/int/float. + This data type only works with arithmetic computation. + :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) """ __slots__ = ['v', 'p', 'z', 's', 'size'] @@ -4835,6 +4877,9 @@ def get_input_from(cls, player): @vectorize_init @read_mem_value def __init__(self, v, p=None, z=None, s=None, size=None): + if program.options.binary: + raise CompilerError( + 'floating-point operations not supported with binary circuits') self.size = get_global_vector_size() if p is None: if isinstance(v, sfloat): @@ -5227,7 +5272,13 @@ class Array(_vectorizable): def create_from(cls, l): """ Convert Python iterator or vector to array. Basic type will be taken from first element, further elements must to be convertible to - that. """ + that. + + :param l: Python iterable or register vector + :returns: :py:class:`Array` of appropriate type containing the contents + of :py:obj:`l` + + """ if isinstance(l, cls): return l if isinstance(l, _number): @@ -6099,12 +6150,12 @@ def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( self[i], other, res_params) - except AttributeError: + except (AttributeError, CompilerError): # fallback for binary circuits - @library.for_range(other.sizes[1]) + @library.for_range_opt(other.sizes[1]) def _(j): res_matrix[i][j] = 0 - @library.for_range(self.sizes[1]) + @library.for_range_opt(self.sizes[1]) def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix @@ -6223,13 +6274,7 @@ def _(i): res[i] = self.direct_mul_trans(other, indices=indices) def direct_mul_to_matrix(self, other): - """ Matrix multiplication in the virtual machine. - - :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :returns: :py:obj:`Matrix` - - """ + # Obsolete. Use dot(). res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) res.assign_vector(self.direct_mul(other)) return res diff --git a/Compiler/util.py b/Compiler/util.py index 9d84df226..c1bedc27c 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -238,6 +238,9 @@ def mem_size(x): except AttributeError: return 1 +def find_in_dict(d, v): + return list(d.keys())[list(d.values()).index(v)] + class set_by_id(object): def __init__(self, init=[]): self.content = {} diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 4657b5d88..27ea7f75c 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -22,7 +22,7 @@ class P256Element : public ValueInterface EC_POINT* point; public: - typedef void next; + typedef P256Element next; typedef void Square; static const true_type invertible; diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index f0e3257c6..5bef730d5 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -45,12 +45,13 @@ int main(int argc, const char** argv) string prefix = get_prep_sub_dir(PREP_DIR "ECDSA/", 2); read_mac_key(prefix, N, keyp); + pShare::MAC_Check::setup(P); + Share::MAC_Check::setup(P); + DataPositions usage; Sub_Data_Files prep(N, prefix, usage); typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); - BaseMachine machine; - machine.ot_setups.push_back({P, false}); SubProcessor proc(_, MCp, prep, P); pShare sk, __; @@ -60,4 +61,7 @@ int main(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts); + + pShare::MAC_Check::teardown(); + Share::MAC_Check::teardown(); } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 569aa791f..ebf0aea96 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -92,9 +92,6 @@ void run(int argc, const char** argv) P256Element::init(); P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); - BaseMachine machine; - machine.ot_setups.push_back({P, true}); - P256Element::Scalar keyp; SeededPRNG G; keyp.randomize(G); @@ -102,6 +99,9 @@ void run(int argc, const char** argv) typedef T pShare; DataPositions usage; + pShare::MAC_Check::setup(P); + T::MAC_Check::setup(P); + OnlineOptions::singleton.batch_size = 1; typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); @@ -137,4 +137,7 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + + pShare::MAC_Check::teardown(); + T::MAC_Check::teardown(); } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 00e051318..62cbd5281 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -130,7 +130,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk) assert(p != 0); for (auto& x : r) { - G.get(x, params->p0().numBits() - p.numBits() - 1); + G.get(x, params->p0().numBits() - p.numBits() - 1); x *= p; } tmp.from(r, 0); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 794e7431d..f3973026e 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -368,7 +368,8 @@ ZZX Cyclotomic(int N) int phi_N(int N) { if (((N - 1) & N) != 0) - throw runtime_error("compile with NTL support"); + throw runtime_error( + "compile with NTL support (USE_NTL=1 in CONFIG.mine)"); else if (N == 1) return 1; else @@ -418,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly) for (int i=0; i& elem) const */ } -void PPData::from_eval(vector& elem) const +void PPData::from_eval(vector&) const { // avoid warning - elem.empty(); throw not_implemented(); /* diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index b19dd62cf..dd3f8968d 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -17,15 +17,13 @@ PairwiseMachine::PairwiseMachine(Player& P) : { } -PairwiseMachine::PairwiseMachine(int argc, const char** argv) : - MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")), - other_pks(N.num_players(), {setup_p.params, 0}), - pk(other_pks[N.my_num()]), sk(pk) +RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) : + MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise")) { init(); } -void PairwiseMachine::init() +void RealPairwiseMachine::init() { if (use_gf2n) { @@ -63,7 +61,7 @@ PairwiseSetup& PairwiseMachine::setup() } template -void PairwiseMachine::setup_keys() +void RealPairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); @@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys() if (i != N.my_num()) other_pks[i].unpack(os[i]); set_mac_key(s.alphai); + Share::MAC_Check::setup(P); } template -void PairwiseMachine::set_mac_key(T alphai) +void RealPairwiseMachine::set_mac_key(T alphai) { typedef typename T::FD FD; auto& N = P; @@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const bundle.compare(P); } -template void PairwiseMachine::setup_keys(); -template void PairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); diff --git a/FHEOffline/PairwiseMachine.h b/FHEOffline/PairwiseMachine.h index c2283443e..a8a0c649e 100644 --- a/FHEOffline/PairwiseMachine.h +++ b/FHEOffline/PairwiseMachine.h @@ -10,7 +10,7 @@ #include "FHEOffline/SimpleMachine.h" #include "FHEOffline/PairwiseSetup.h" -class PairwiseMachine : public MachineBase +class PairwiseMachine : public virtual MachineBase { public: PairwiseSetup setup_p; @@ -23,15 +23,6 @@ class PairwiseMachine : public MachineBase vector enc_alphas; PairwiseMachine(Player& P); - PairwiseMachine(int argc, const char** argv); - - void init(); - - template - void setup_keys(); - - template - void set_mac_key(T alphai); template PairwiseSetup& setup(); @@ -42,4 +33,18 @@ class PairwiseMachine : public MachineBase void check(Player& P) const; }; +class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine +{ +public: + RealPairwiseMachine(int argc, const char** argv); + + void init(); + + template + void setup_keys(); + + template + void set_mac_key(T alphai); +}; + #endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */ diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp index b2701b2c5..be5ee2c19 100644 --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -12,7 +12,7 @@ template