diff --git a/package/CMakeLists.txt b/package/CMakeLists.txt index 587a3f8..4597f1d 100644 --- a/package/CMakeLists.txt +++ b/package/CMakeLists.txt @@ -50,7 +50,7 @@ FetchContent_Declare( # Download Gemmi FetchContent_Declare( gemmi-dependencies - URL http://www.ysbl.york.ac.uk/jsd523/gemmi-0.6.5.tar.gz + URL https://github.com/Dialpuri/gemmi-bundles/raw/refs/heads/main/gemmi-0.7.3.tar.gz ) FetchContent_MakeAvailable(clipper-dependencies mmdb2-dependencies fftw-dependencies ccp4-dependencies gemmi-dependencies) @@ -101,8 +101,10 @@ add_library( ${WRK_DIR}/src/cpp/sails-telemetry.cpp ${WRK_DIR}/src/cpp/sails-solvent.cpp ${WRK_DIR}/src/cpp/sails-wurcs.cpp - + ${WRK_DIR}/src/cpp/sails-predictions.cpp ${WRK_DIR}/src/cpp/sails-morph.cpp + ${WRK_DIR}/src/cpp/sails-score.cpp + # Density ${WRK_DIR}/src/cpp/density/sails-density.cpp ${WRK_DIR}/src/cpp/density/sails-xtal-density.cpp diff --git a/package/_pyproject.toml b/package/_pyproject.toml index 70b8d02..a509eb6 100644 --- a/package/_pyproject.toml +++ b/package/_pyproject.toml @@ -18,7 +18,10 @@ dependencies=[ 'tqdm', 'gemmi', 'numpy<2.0.0', - 'typing-extensions' + 'typing-extensions', + 'onnxruntime-gpu; platform_system != "Darwin"', + 'onnxruntime; platform_system == "Darwin"', + 'requests' ] [project.urls] @@ -32,6 +35,10 @@ sails-find = "sails.find:run" sails-test = "sails.test:run" sails-wurcs = "sails.wurcs:run" sails-morph = "sails.morph:run" +sails-predict = "sails.prediction.predict:run" +sails-install = "sails.install:run" +sails-clean = "sails.clean:run" +sails-validate = "sails.validate:run" [tool.scikit-build] # Protect the configuration against future changes in scikit-build-core diff --git a/package/gemmi/CMakeLists.txt b/package/gemmi/CMakeLists.txt index ce1608a..9451544 100644 --- a/package/gemmi/CMakeLists.txt +++ b/package/gemmi/CMakeLists.txt @@ -22,24 +22,31 @@ add_library(gemmi_cpp STATIC ${gemmi_src}/src/align.cpp ${gemmi_src}/src/assembly.cpp ${gemmi_src}/src/calculate.cpp + 
${gemmi_src}/src/ccp4.cpp ${gemmi_src}/src/crd.cpp ${gemmi_src}/src/ddl.cpp + ${gemmi_src}/src/dssp.cpp ${gemmi_src}/src/eig3.cpp + ${gemmi_src}/src/fprime.cpp ${gemmi_src}/src/gz.cpp ${gemmi_src}/src/intensit.cpp ${gemmi_src}/src/json.cpp ${gemmi_src}/src/mmcif.cpp ${gemmi_src}/src/mmread_gz.cpp + ${gemmi_src}/src/monlib.cpp ${gemmi_src}/src/mtz.cpp ${gemmi_src}/src/mtz2cif.cpp ${gemmi_src}/src/polyheur.cpp + ${gemmi_src}/src/pdb.cpp ${gemmi_src}/src/read_cif.cpp ${gemmi_src}/src/resinfo.cpp ${gemmi_src}/src/riding_h.cpp + ${gemmi_src}/src/select.cpp ${gemmi_src}/src/sprintf.cpp + ${gemmi_src}/src/symmetry.cpp + ${gemmi_src}/src/to_json.cpp ${gemmi_src}/src/to_mmcif.cpp ${gemmi_src}/src/to_pdb.cpp - ${gemmi_src}/src/monlib.cpp ${gemmi_src}/src/topo.cpp ${gemmi_src}/src/xds_ascii.cpp ) @@ -63,7 +70,6 @@ set(gemmi_HEADERS ${gemmi_src}/include/gemmi/ccp4.hpp ${gemmi_src}/include/gemmi/cellred.hpp ${gemmi_src}/include/gemmi/chemcomp.hpp - ${gemmi_src}/include/gemmi/chemcomp_xyz.hpp ${gemmi_src}/include/gemmi/cif.hpp ${gemmi_src}/include/gemmi/cif2mtz.hpp ${gemmi_src}/include/gemmi/cifdoc.hpp @@ -72,6 +78,7 @@ set(gemmi_HEADERS ${gemmi_src}/include/gemmi/ddl.hpp ${gemmi_src}/include/gemmi/dencalc.hpp ${gemmi_src}/include/gemmi/dirwalk.hpp + ${gemmi_src}/include/gemmi/dssp.hpp ${gemmi_src}/include/gemmi/ecalc.hpp ${gemmi_src}/include/gemmi/eig3.hpp ${gemmi_src}/include/gemmi/elem.hpp @@ -86,20 +93,21 @@ set(gemmi_HEADERS ${gemmi_src}/include/gemmi/grid.hpp ${gemmi_src}/include/gemmi/gz.hpp ${gemmi_src}/include/gemmi/input.hpp + ${gemmi_src}/include/gemmi/intensit.hpp ${gemmi_src}/include/gemmi/interop.hpp ${gemmi_src}/include/gemmi/it92.hpp ${gemmi_src}/include/gemmi/iterator.hpp ${gemmi_src}/include/gemmi/json.hpp ${gemmi_src}/include/gemmi/levmar.hpp ${gemmi_src}/include/gemmi/linkhunt.hpp + ${gemmi_src}/include/gemmi/logger.hpp ${gemmi_src}/include/gemmi/math.hpp -# ${gemmi_src}/include/gemmi/merge.hpp ${gemmi_src}/include/gemmi/metadata.hpp - 
${gemmi_src}/include/gemmi/mmcif.hpp ${gemmi_src}/include/gemmi/mmcif_impl.hpp + ${gemmi_src}/include/gemmi/mmcif.hpp ${gemmi_src}/include/gemmi/mmdb.hpp - ${gemmi_src}/include/gemmi/mmread.hpp ${gemmi_src}/include/gemmi/mmread_gz.hpp + ${gemmi_src}/include/gemmi/mmread.hpp ${gemmi_src}/include/gemmi/model.hpp ${gemmi_src}/include/gemmi/modify.hpp ${gemmi_src}/include/gemmi/monlib.hpp @@ -108,17 +116,15 @@ set(gemmi_HEADERS ${gemmi_src}/include/gemmi/neighbor.hpp ${gemmi_src}/include/gemmi/neutron92.hpp ${gemmi_src}/include/gemmi/numb.hpp - ${gemmi_src}/include/gemmi/pdb.hpp ${gemmi_src}/include/gemmi/pdb_id.hpp + ${gemmi_src}/include/gemmi/pdb.hpp ${gemmi_src}/include/gemmi/pirfasta.hpp ${gemmi_src}/include/gemmi/polyheur.hpp ${gemmi_src}/include/gemmi/qcp.hpp ${gemmi_src}/include/gemmi/read_cif.hpp - ${gemmi_src}/include/gemmi/read_map.hpp ${gemmi_src}/include/gemmi/recgrid.hpp ${gemmi_src}/include/gemmi/reciproc.hpp ${gemmi_src}/include/gemmi/refln.hpp - ${gemmi_src}/include/gemmi/remarks.hpp ${gemmi_src}/include/gemmi/resinfo.hpp ${gemmi_src}/include/gemmi/riding_h.hpp ${gemmi_src}/include/gemmi/scaling.hpp @@ -126,6 +132,7 @@ set(gemmi_HEADERS ${gemmi_src}/include/gemmi/seqalign.hpp ${gemmi_src}/include/gemmi/seqid.hpp ${gemmi_src}/include/gemmi/seqtools.hpp + ${gemmi_src}/include/gemmi/serialize.hpp ${gemmi_src}/include/gemmi/sfcalc.hpp ${gemmi_src}/include/gemmi/small.hpp ${gemmi_src}/include/gemmi/smcif.hpp @@ -146,14 +153,14 @@ set(gemmi_HEADERS ${gemmi_src}/include/gemmi/util.hpp ${gemmi_src}/include/gemmi/version.hpp ${gemmi_src}/include/gemmi/xds_ascii.hpp - + ${gemmi_src}/include/gemmi/xds2mtz.hpp ) set(gemmi_third_party-headers_HEADERS ${gemmi_src}/include/gemmi/third_party/fast_float.h ${gemmi_src}/include/gemmi/third_party/pocketfft_hdronly.h -# ${gemmi_src}/include/gemmi/third_party/sajson.h + # ${gemmi_src}/include/gemmi/third_party/sajson.h ${gemmi_src}/include/gemmi/third_party/tinydir.h ) @@ -198,7 +205,7 @@ 
set(gemmi_third_party_tao_pegtl_analysis-headers_HEADERS ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/generic.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/grammar_info.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/insert_guard.hpp - ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/insert_rules.hpp + # ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/insert_rules.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/rule_info.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/analysis/rule_type.hpp ) @@ -217,7 +224,7 @@ set(gemmi_third_party_tao_pegtl_internal-headers_HEADERS ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/bof.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/bol.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/bump_help.hpp - ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/bump_impl.hpp + # ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/bump_impl.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/bytes.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/control.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/cr_crlf_eol.hpp @@ -240,7 +247,7 @@ set(gemmi_third_party_tao_pegtl_internal-headers_HEADERS ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/eof.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/eol.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/eolf.hpp - ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/file_mapper.hpp + # ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/file_mapper.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/file_opener.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/file_reader.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/has_apply.hpp @@ -262,7 +269,7 @@ set(gemmi_third_party_tao_pegtl_internal-headers_HEADERS 
${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/list_tail.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/list_tail_pad.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/marker.hpp - ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/minus.hpp + # ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/minus.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/must.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/not_at.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/one.hpp @@ -284,7 +291,7 @@ set(gemmi_third_party_tao_pegtl_internal-headers_HEADERS ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/rep_opt.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/require.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/result_on_found.hpp - ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/rule_conjunction.hpp + # ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/rule_conjunction.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/rules.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/seq.hpp ${gemmi_src}/include/gemmi/third_party/tao/pegtl/internal/skip_control.hpp diff --git a/package/pyproject.toml b/package/pyproject.toml index bd544d5..4e96959 100644 --- a/package/pyproject.toml +++ b/package/pyproject.toml @@ -18,11 +18,15 @@ dependencies=[ 'tqdm', 'gemmi', 'numpy<2.0.0', - 'typing-extensions' + 'typing-extensions', + 'onnxruntime-gpu; platform_system != "Darwin"', + 'onnxruntime; platform_system == "Darwin"', + 'requests', + 'nanobind==2.4.0' ] [tool.setuptools] -packages = ["sails"] +packages = ["sails", "sails.prediction"] package-dir = {"" = "src"} include-package-data = true @@ -40,6 +44,10 @@ sails-find = "sails.find:run" sails-test = "sails.test:run" sails-wurcs = "sails.wurcs:run" sails-morph = "sails.morph:run" +sails-predict = "sails.prediction.predict:run" +sails-install = "sails.install:run" +sails-clean 
= "sails.clean:run" +sails-validate = "sails.validate:run" [tool.scikit-build] # Protect the configuration against future changes in scikit-build-core diff --git a/package/scripts/compare_structures.py b/package/scripts/compare_structures.py index 68e2382..8cc6cd6 100644 --- a/package/scripts/compare_structures.py +++ b/package/scripts/compare_structures.py @@ -2,6 +2,7 @@ import gemmi import json import numpy as np +from pprint import pprint def load_data_file(filename): @@ -20,11 +21,11 @@ def format_residue(chain: gemmi.Chain, residue: gemmi.Residue): def main(args): - data = load_data_file("package/data/data.json") + data = load_data_file("package/src/sails/data/data.json") structure = gemmi.read_structure(args.model) reference = gemmi.read_structure(args.reference) - ns = gemmi.NeighborSearch(structure, max_radius=1).populate() + ns = gemmi.NeighborSearch(structure, max_radius=1.5).populate() output = {} @@ -84,6 +85,8 @@ def main(args): percentage_modelled = 100 * modelled / total_sugars print(f"Percentage Modelled {percentage_modelled:.2f}") + pprint(output) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/package/src/bindings/python_sails.cpp b/package/src/bindings/python_sails.cpp index f09bf77..7b14782 100644 --- a/package/src/bindings/python_sails.cpp +++ b/package/src/bindings/python_sails.cpp @@ -38,6 +38,58 @@ NB_MODULE(sails_module, m) { .def_rw("delfwt_phdelwt", &Sails::Reflection::delfwt_phdelwt); + nb::class_(m, "ResidueId") + .def_rw("seqid", &gemmi::ResidueId::seqid) + .def_rw("segment", &gemmi::ResidueId::segment) + .def_rw("name", &gemmi::ResidueId::name); + + nb::class_(m, "AtomAddress") + .def(nb::init<>()) + .def(nb::init(), + nb::arg("chain_name"), nb::arg("res_id"), nb::arg("atom_name"), nb::arg("altloc") = '\0') + .def(nb::init(), + nb::arg("chain_name"), nb::arg("seq_id"), nb::arg("res"), nb::arg("atom_name"), nb::arg("altloc") = '\0') + .def_rw("chain_name", &gemmi::AtomAddress::chain_name) + 
.def_rw("res_id", &gemmi::AtomAddress::res_id) + .def_rw("atom_name", &gemmi::AtomAddress::atom_name) + .def_rw("altloc", &gemmi::AtomAddress::altloc) + .def("__eq__", &gemmi::AtomAddress::operator==) + .def("__str__", &gemmi::AtomAddress::str); + + nb::enum_(m, "ConnectionType") + .value("Covale", gemmi::Connection::Type::Covale) + .value("Disulf", gemmi::Connection::Type::Disulf) + .value("Hydrog", gemmi::Connection::Type::Hydrog) + .value("MetalC", gemmi::Connection::Type::MetalC) + .value("Unknown", gemmi::Connection::Type::Unknown); + + nb::bind_vector >(m, "Connections"); + + nb::class_(m, "Connection") + .def(nb::init<>()) + .def_rw("name", &gemmi::Connection::name) + .def_rw("link_id", &gemmi::Connection::link_id) + .def_rw("type", &gemmi::Connection::type) + .def_rw("asu", &gemmi::Connection::asu) + .def_rw("partner1", &gemmi::Connection::partner1) + .def_rw("partner2", &gemmi::Connection::partner2) + .def_rw("reported_distance", &gemmi::Connection::reported_distance) + .def_prop_rw( + "reported_sym", + [](gemmi::Connection &self) { + return nb::cast(std::array{ + self.reported_sym[0], + self.reported_sym[1], + self.reported_sym[2], + self.reported_sym[3] + }); + }, + [](gemmi::Connection &self, const std::array &arr) { + for (size_t i = 0; i < 4; ++i) + self.reported_sym[i] = arr[i]; + }); + + // gemmi Structure nb::class_(m, "Structure") .def(nb::init<>()) @@ -45,6 +97,8 @@ NB_MODULE(sails_module, m) { .def("cell", [](const gemmi::Structure &structure) { return Sails::Cell(structure.cell); }) + .def_rw("connections", &gemmi::Structure::connections) + .def_rw("spacegroup_hm", &gemmi::Structure::spacegroup_hm) .def("set_cell", [](gemmi::Structure &structure, const Sails::Cell &cell) { structure.cell = gemmi::UnitCell(cell.a, cell.b, cell.c, cell.alpha, cell.beta, cell.gamma); @@ -54,7 +108,7 @@ NB_MODULE(sails_module, m) { nb::class_(m, "Model") .def(nb::init<>()) - .def_rw("name", &gemmi::Model::name) + .def_rw("num", &gemmi::Model::num) .def_rw("chains", 
&gemmi::Model::chains); nb::bind_vector >(m, "Chains"); @@ -139,6 +193,7 @@ NB_MODULE(sails_module, m) { .def_rw("chain_idx", &Sails::Glycosite::chain_idx) .def_rw("residue_idx", &Sails::Glycosite::residue_idx) .def_rw("atom_idx", &Sails::Glycosite::atom_idx); + nb::bind_vector >(m, "GlycoSites"); nb::class_(m, "Dot") .def(nb::init()) @@ -188,14 +243,36 @@ NB_MODULE(sails_module, m) { "mtz"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); m.def("n_glycosylate", - nb::overload_cast &, int, std::string &, bool>(&n_glycosylate), - "structure"_a, "grid"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + nb::overload_cast &, float, int, std::string &, bool>(&n_glycosylate), + "structure"_a, "grid"_a, "resolution"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); m.def("c_glycosylate", - nb::overload_cast &, int, std::string &, bool>(&c_glycosylate), - "structure"_a, "grid"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + nb::overload_cast &, float, int, std::string &, bool>(&c_glycosylate), + "structure"_a, "grid"_a, "resolution"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); m.def("o_mannosylate", - nb::overload_cast &, int, std::string &, bool>(&o_mannosylate), - "structure"_a, "grid"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + nb::overload_cast &, float, int, std::string &, bool>(&o_mannosylate), + "structure"_a, "grid"_a, "resolution"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + + m.def("auto_glycosylate", + nb::overload_cast&, gemmi::Grid<>&, int, std::string &, bool>(&auto_glycosylate), "structure"_a, + "mtz"_a, "glycan_grid"_a, "protein_grid"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + + m.def("auto_glycosylate", + nb::overload_cast &, float, gemmi::Grid<>&, gemmi::Grid<>&, int, std::string &, bool>(&auto_glycosylate), "structure"_a, + "grid"_a, "resolution"_a, "glycan_grid"_a, "protein_grid"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + + m.def("glycosylate_site", + nb::overload_cast(&glycosylate_site), "structure"_a, + "mtz"_a, "chain"_a, "seqid"_a, 
"cycles"_a, "resource_dir"_a, "verbose"_a); + + m.def("glycosylate_site", + nb::overload_cast &, float, std::string&, int, int, std::string &, bool>(&glycosylate_site), "structure"_a, + "grid"_a, "resolution"_a, "chain"_a, "seqid"_a, "cycles"_a, "resource_dir"_a, "verbose"_a); + + m.def("identify_predicted_sites", nb::overload_cast&, std::string &>(&identify_predicted_sites), + "structure"_a, "glycan_grid"_a, "resource_dir"_a); + m.def("identify_predicted_sites", nb::overload_cast&, gemmi::Grid<>&, bool, std::string &>(&identify_predicted_sites), + "structure"_a, "glycan_grid"_a, "protein_grid"_a, "use_glycan"_a, "resource_dir"_a); + m.def("find_all_wurcs", &find_all_wurcs, "structure"_a, "resource_dir"_a); m.def("find_wurcs", &find_wurcs, "structure"_a, "chain"_a, "seqid"_a, "resource_dir"_a); @@ -203,6 +280,13 @@ NB_MODULE(sails_module, m) { m.def("morph", &morph, "structure"_a, "wurcs"_a, "chain"_a, "seqid"_a, "resource_dir"_a); + // XRAY + m.def("validate", nb::overload_cast(&validate), "structure"_a, "mtz"_a, "remove"_a, "threshold"_a, "resource_dir"_a); + m.def("validate_site", nb::overload_cast(&validate_site), "structure"_a, "mtz"_a, "chain"_a, "seqid"_a, "remove"_a, "threshold"_a, "resource_dir"_a); + + // EM + m.def("validate", nb::overload_cast &, float, bool, float, bool, std::string &>(&validate), "structure"_a, "grid"_a, "resolution"_a, "remove"_a, "threshold"_a, "use_q"_a, "resource_dir"_a); + m.def("test_snfg", &test); m.def("get_snfg", &get_snfg, "chain"_a, "seqid"_a , "structure"_a, "resource_dir"_a); diff --git a/package/src/cpp/density/sails-density.cpp b/package/src/cpp/density/sails-density.cpp index ed5ee83..8b4481c 100644 --- a/package/src/cpp/density/sails-density.cpp +++ b/package/src/cpp/density/sails-density.cpp @@ -8,6 +8,8 @@ #include #include +#include "src/include/sails-score.h" + double Sails::Density::score_residue(gemmi::Residue &residue, const DensityScoreMethod &method) { switch (method) { @@ -17,8 +19,8 @@ double 
Sails::Density::score_residue(gemmi::Residue &residue, const DensityScore return rscc_score(residue); case rsr: return rsr_score(residue); - case dds: - return difference_density_score(residue); + case q: + return q_score(residue); default: return -1; } @@ -32,8 +34,8 @@ double Sails::Density::score_result(SuperpositionResult& result) { return rscc_score(result); case rsr: return rsr_score(result); - case dds: - return difference_density_score(result.new_residue); + case q: + return q_score(result.new_residue); default: return -1; } @@ -46,48 +48,65 @@ float Sails::Density::atomwise_score(const gemmi::Residue &residue) const { }) / (residue.atoms.size()); } -gemmi::Grid<> Sails::Density::calculate_density_for_box(gemmi::Residue &residue, gemmi::Box &box) const { - - gemmi::DensityCalculator, float> density_calculator; - - gemmi::Position size = box.get_size(); - gemmi::UnitCell dummy_cell = {size.x, size.y, size.z, 90, 90, 90}; - density_calculator.grid.unit_cell = dummy_cell; - density_calculator.grid.nu = size.x; - density_calculator.grid.nv = size.y; - density_calculator.grid.nw = size.z; - density_calculator.grid.spacegroup = get_work_grid()->spacegroup; - density_calculator.grid.axis_order = get_work_grid()->axis_order; - - density_calculator.d_min = 1; - density_calculator.initialize_grid(); - for (auto &atom: residue.atoms) { - density_calculator.add_atom_density_to_grid(atom); - } - density_calculator.grid.symmetrize_sum(); - return density_calculator.grid; -} - -gemmi::Grid<> Sails::Density::calculate_density_for_grid(gemmi::Residue &residue) const { - - gemmi::DensityCalculator, float> density_calculator; - - density_calculator.grid.copy_metadata_from(*get_work_grid()); - density_calculator.grid.spacing[0] = get_work_grid()->spacing[0]; - density_calculator.grid.spacing[1] = get_work_grid()->spacing[1]; - density_calculator.grid.spacing[2] = get_work_grid()->spacing[2]; - - density_calculator.d_min = get_resolution(); - 
density_calculator.initialize_grid(); - for (auto &atom: residue.atoms) { - density_calculator.add_atom_density_to_grid(atom); - } - density_calculator.grid.symmetrize_sum(); - auto x = density_calculator.grid; - return std::move(x); -} - -float Sails::Density::calculate_rscc(std::vector obs_values, std::vector calc_values) { +// gemmi::Grid<> Sails::Density::calculate_density_for_box(gemmi::Residue &residue, gemmi::Box &box) const { +// +// gemmi::DensityCalculator, float> density_calculator; +// +// gemmi::Position size = box.get_size(); +// gemmi::UnitCell dummy_cell = {size.x, size.y, size.z, 90, 90, 90}; +// density_calculator.grid.unit_cell = dummy_cell; +// density_calculator.grid.nu = size.x; +// density_calculator.grid.nv = size.y; +// density_calculator.grid.nw = size.z; +// density_calculator.grid.spacegroup = get_work_grid()->spacegroup; +// density_calculator.grid.axis_order = get_work_grid()->axis_order; +// +// density_calculator.d_min = 1; +// density_calculator.initialize_grid(); +// for (auto &atom: residue.atoms) { +// density_calculator.add_atom_density_to_grid(atom); +// } +// density_calculator.grid.symmetrize_sum(); +// return density_calculator.grid; +// } +// +// gemmi::Grid<> Sails::Density::calculate_density_for_grid(gemmi::Residue &residue) const { +// +// gemmi::DensityCalculator, float> density_calculator; +// +// density_calculator.grid.copy_metadata_from(*get_best_grid()); +// density_calculator.grid.spacing[0] = get_best_grid()->spacing[0]; +// density_calculator.grid.spacing[1] = get_best_grid()->spacing[1]; +// density_calculator.grid.spacing[2] = get_best_grid()->spacing[2]; +// +// density_calculator.d_min = get_resolution(); +// density_calculator.initialize_grid(); +// for (auto &atom: residue.atoms) { +// density_calculator.add_atom_density_to_grid(atom); +// } +// density_calculator.grid.symmetrize_sum(); +// auto x = density_calculator.grid; +// return std::move(x); +// } +// +// gemmi::Grid<> 
Sails::Density::calculate_density_for_structure(gemmi::Structure &structure) const { +// gemmi::DensityCalculator, float> density_calculator; +// +// density_calculator.grid.copy_metadata_from(*get_best_grid()); +// density_calculator.grid.spacing[0] = get_best_grid()->spacing[0]; +// density_calculator.grid.spacing[1] = get_best_grid()->spacing[1]; +// density_calculator.grid.spacing[2] = get_best_grid()->spacing[2]; +// +// density_calculator.d_min = get_resolution(); +// density_calculator.initialize_grid(); +// density_calculator.add_model_density_to_grid(structure.models[0]); +// density_calculator.grid.symmetrize_sum(); +// auto x = density_calculator.grid; +// return std::move(x); +// } + +template +T Sails::Density::calculate_rscc(std::vector obs_values, std::vector calc_values) { if (obs_values.size() != calc_values.size()) throw std::runtime_error("RSCC obs and calc lists are different sizes"); @@ -117,6 +136,8 @@ float Sails::Density::calculate_rscc(std::vector obs_values, std::vector< if (denominator == 0.0f) throw std::runtime_error("RSCC Denominator is 0"); return numerator / denominator; } +template float Sails::Density::calculate_rscc(std::vector obs_values, std::vector calc_values); +template double Sails::Density::calculate_rscc(std::vector obs_values, std::vector calc_values); float Sails::Density::rscc_score(gemmi::Residue &residue) const { @@ -126,11 +147,14 @@ float Sails::Density::rscc_score(gemmi::Residue &residue) const { for (auto &atom: residue.atoms) { box.extend(atom.pos); } - box.add_margin(1); + box.add_margin(2); // gemmi::Grid<> calc = calculate_density_for_box(residue, box); gemmi::Grid<> calc = calculate_density_for_grid(residue); + gemmi::Model model = Utils::create_model(residue); + gemmi::NeighborSearch ns = {model, get_best_grid()->unit_cell, 2}; + ns.populate(); // gemmi::Ccp4<> m; // m.grid = calc; // m.update_ccp4_header(); @@ -139,25 +163,26 @@ float Sails::Density::rscc_score(gemmi::Residue &residue) const { // 
std::vector rs = {residue}; // Utils::save_residues_to_file(rs, "res.pdb"); - - const gemmi::Position max = box.maximum; const gemmi::Position min = box.minimum; + const gemmi::Position max = box.maximum; std::vector obs_values = {}; std::vector calc_values = {}; - constexpr double step_size = 0.5; - for (double x = min.x; x <= max.x; x += step_size) { - for (double y = min.y; y <= max.y; y += step_size) { - for (double z = min.z; z <= max.z; z += step_size) { + for (double x = min.x; x <= max.x; x += get_best_grid()->spacing[0]) { + for (double y = min.y; y <= max.y; y += get_best_grid()->spacing[1]) { + for (double z = min.z; z <= max.z; z += get_best_grid()->spacing[2]) { gemmi::Position position = {x, y, z}; - obs_values.emplace_back(get_best_grid()->interpolate_value(position)); - calc_values.emplace_back(calc.interpolate_value(position)); + auto nearest_atom = ns.find_atoms(position, '*', 0, 2); + if (!nearest_atom.empty()) { + obs_values.emplace_back(get_best_grid()->interpolate_value(position)); + calc_values.emplace_back(calc.interpolate_value(position)); + } } } } - return calculate_rscc(obs_values, calc_values); + return calculate_rscc(obs_values, calc_values); } float Sails::Density::rscc_score(SuperpositionResult &result) { @@ -196,7 +221,7 @@ float Sails::Density::rscc_score(SuperpositionResult &result) { } } - return calculate_rscc(obs_values, calc_values); + return calculate_rscc(obs_values, calc_values); } float Sails::Density::rsr_score(gemmi::Residue &residue) { @@ -270,30 +295,44 @@ float Sails::Density::rsr_score(SuperpositionResult &result) { return numerator / denominator; } -float Sails::Density::difference_density_score(gemmi::Residue &residue) const { - gemmi::Box box; - for (auto &atom: residue.atoms) { - box.extend(atom.pos); - } +int Sails::Density::check_difference_density(gemmi::Residue &residue, std::pair map_stats) const { - const gemmi::Position max = box.maximum; - const gemmi::Position min = box.minimum; + float threshold = 
map_stats.first - 2 * map_stats.second; - float sum = 0.0f; - int points = 0; - constexpr double step_size = 0.5; - for (double x = min.x; x <= max.x; x += step_size) { - for (double y = min.y; y <= max.y; y += step_size) { - for (double z = min.z; z <= max.z; z += step_size) { - gemmi::Position position = {x, y, z}; - float value = get_difference_grid()->interpolate_value(position); - sum += abs(value); - points++; - } + std::set ring_atoms = { + "C1", "C2", "C3", "C4", "C5", "O5" + }; + int i = 0; + for (auto & atom : residue.atoms) { + // if (ring_atoms.count(atom.name) == 0) continue; + if (get_difference_grid()->interpolate_value(atom.pos) < threshold) { + i++; } } - - return sum / points; + return i; + // gemmi::Box box; + // for (auto &atom: residue.atoms) { + // box.extend(atom.pos); + // } + // + // const gemmi::Position max = box.maximum; + // const gemmi::Position min = box.minimum; + // + // float sum = 0.0f; + // int points = 0; + // constexpr double step_size = 0.5; + // for (double x = min.x; x <= max.x; x += step_size) { + // for (double y = min.y; y <= max.y; y += step_size) { + // for (double z = min.z; z <= max.z; z += step_size) { + // gemmi::Position position = {x, y, z}; + // float value = get_difference_grid()->interpolate_value(position); + // sum += abs(value); + // points++; + // } + // } + // } + // + // return sum / points; } float Sails::Density::score_atomic_position(const gemmi::Atom &atom) const { @@ -304,3 +343,45 @@ float Sails::Density::score_atomic_position(const gemmi::Atom &atom) const { float Sails::Density::score_position(const gemmi::Position &pos) const { return get_work_grid()->interpolate_value(pos); } + +std::pair Sails::Density::calculate_map_statistics(const gemmi::Grid<> *grid) const { + const float sum = std::accumulate(grid->data.begin(), grid->data.end(), 0.0f); + float mean = sum / grid->data.size(); + + float sq_sum = std::accumulate(grid->data.begin(), grid->data.end(), 0.0, + [mean](const double acc, const 
double x) { + const double diff = x - mean; + return acc + diff * diff; + }); + + float stdev = std::sqrt(sq_sum / grid->data.size()); + + return std::make_pair(mean, stdev); +} + +double Sails::Density::q_score(gemmi::Residue &residue) { + auto [mean, stddev] = get_map_stats(); + + const float A = mean + (10 * stddev); + const float B = mean - stddev; + constexpr float sigma = 0.6; + constexpr int N = 8; + + gemmi::Model model = Utils::create_model(residue); + gemmi::NeighborSearch ns = {model, get_best_grid()->unit_cell, 2}; + ns.populate(); + + std::vector residue_q_scores = {}; + + for (int a = 0; a < residue.atoms.size(); a++) { + Glycosite atom_site = {0, 0, 0, a}; + double atom_q = Score::QScore::calculate_q_score(residue.atoms[a].pos, atom_site, get_work_grid(), + ns, A, B, sigma, N); + residue_q_scores.emplace_back(atom_q); + } + + const double mean_residue_q_score = std::accumulate(residue_q_scores.begin(), residue_q_scores.end(), 0.0) + / static_cast(residue.atoms.size()); + + return mean_residue_q_score; +} diff --git a/package/src/cpp/density/sails-em-density.cpp b/package/src/cpp/density/sails-em-density.cpp index 3b57af9..b968858 100644 --- a/package/src/cpp/density/sails-em-density.cpp +++ b/package/src/cpp/density/sails-em-density.cpp @@ -5,6 +5,63 @@ #include "../../include/density/sails-density.h" #include "../../include/density/sails-em-density.h" -Sails::EMDensity::EMDensity(gemmi::Grid<> &grid) { +Sails::EMDensity::EMDensity(gemmi::Grid<> &grid, float resolution) { m_grid = grid; + m_resolution = resolution; +} + +gemmi::Grid<> Sails::EMDensity::calculate_density_for_box(gemmi::Residue &residue, + gemmi::Box &box) const { + gemmi::DensityCalculator, float> density_calculator; + + gemmi::Position size = box.get_size(); + gemmi::UnitCell dummy_cell = {size.x, size.y, size.z, 90, 90, 90}; + density_calculator.grid.unit_cell = dummy_cell; + density_calculator.grid.nu = size.x; + density_calculator.grid.nv = size.y; + density_calculator.grid.nw = 
size.z; + density_calculator.grid.spacegroup = get_work_grid()->spacegroup; + density_calculator.grid.axis_order = get_work_grid()->axis_order; + + density_calculator.d_min = 1; + density_calculator.initialize_grid(); + for (auto &atom: residue.atoms) { + density_calculator.add_atom_density_to_grid(atom); + } + density_calculator.grid.symmetrize_sum(); + return density_calculator.grid; +} + +gemmi::Grid<> Sails::EMDensity::calculate_density_for_grid(gemmi::Residue &residue) const { + gemmi::DensityCalculator, float> density_calculator; + + density_calculator.grid.copy_metadata_from(*get_best_grid()); + density_calculator.grid.spacing[0] = get_best_grid()->spacing[0]; + density_calculator.grid.spacing[1] = get_best_grid()->spacing[1]; + density_calculator.grid.spacing[2] = get_best_grid()->spacing[2]; + + density_calculator.d_min = get_resolution(); + density_calculator.initialize_grid(); + for (auto &atom: residue.atoms) { + density_calculator.add_atom_density_to_grid(atom); + } + density_calculator.grid.symmetrize_sum(); + auto x = density_calculator.grid; + return std::move(x); +} + +gemmi::Grid<> Sails::EMDensity::calculate_density_for_structure(gemmi::Structure &structure) const { + gemmi::DensityCalculator, float> density_calculator; + + density_calculator.grid.copy_metadata_from(*get_best_grid()); + density_calculator.grid.spacing[0] = get_best_grid()->spacing[0]; + density_calculator.grid.spacing[1] = get_best_grid()->spacing[1]; + density_calculator.grid.spacing[2] = get_best_grid()->spacing[2]; + + density_calculator.d_min = get_resolution(); + density_calculator.initialize_grid(); + density_calculator.add_model_density_to_grid(structure.models[0]); + density_calculator.grid.symmetrize_sum(); + auto x = density_calculator.grid; + return std::move(x); } diff --git a/package/src/cpp/density/sails-xtal-density.cpp b/package/src/cpp/density/sails-xtal-density.cpp index 4f6b360..dc06b8b 100644 --- a/package/src/cpp/density/sails-xtal-density.cpp +++ 
b/package/src/cpp/density/sails-xtal-density.cpp @@ -17,6 +17,10 @@ Sails::XtalDensity::XtalDensity(gemmi::Mtz &mtz, const std::string& F, const std load_hkl(F, SIGF); } +void Sails::XtalDensity::load_map_coefficients(const std::string &fwt, const std::string &phwt) { + m_grid = load_grid(m_mtz, fwt, phwt, false); +} + void Sails::XtalDensity::initialise_hkl() { m_resolution = clipper::Resolution(m_mtz.resolution_high()); @@ -60,7 +64,7 @@ void Sails::XtalDensity::load_hkl(const std::string &f, const std::string &sig_f gemmi::Grid<> Sails::XtalDensity::load_grid(const gemmi::Mtz &mtz, const std::string &f_col, const std::string &phi_col, bool normalise) { constexpr std::array null_size = {0, 0, 0}; - constexpr double sample_rate = 0; + constexpr double sample_rate = 3; constexpr auto order = gemmi::AxisOrder::XYZ; const gemmi::Mtz::Column &f = mtz.get_column_with_label(f_col); @@ -93,6 +97,62 @@ void Sails::XtalDensity::form_atom_list(const gemmi::Structure &structure, std:: } +gemmi::Grid<> Sails::XtalDensity::calculate_density_for_box(gemmi::Residue &residue, + gemmi::Box &box) const { + gemmi::DensityCalculator, float> density_calculator; + + gemmi::Position size = box.get_size(); + gemmi::UnitCell dummy_cell = {size.x, size.y, size.z, 90, 90, 90}; + density_calculator.grid.unit_cell = dummy_cell; + density_calculator.grid.nu = size.x; + density_calculator.grid.nv = size.y; + density_calculator.grid.nw = size.z; + density_calculator.grid.spacegroup = get_work_grid()->spacegroup; + density_calculator.grid.axis_order = get_work_grid()->axis_order; + + density_calculator.d_min = 1; + density_calculator.initialize_grid(); + for (auto &atom: residue.atoms) { + density_calculator.add_atom_density_to_grid(atom); + } + density_calculator.grid.symmetrize_sum(); + return density_calculator.grid; +} + +gemmi::Grid<> Sails::XtalDensity::calculate_density_for_grid(gemmi::Residue &residue) const { + gemmi::DensityCalculator, float> density_calculator; + + 
density_calculator.grid.copy_metadata_from(*get_best_grid()); + density_calculator.grid.spacing[0] = get_best_grid()->spacing[0]; + density_calculator.grid.spacing[1] = get_best_grid()->spacing[1]; + density_calculator.grid.spacing[2] = get_best_grid()->spacing[2]; + + density_calculator.d_min = get_resolution(); + density_calculator.initialize_grid(); + for (auto &atom: residue.atoms) { + density_calculator.add_atom_density_to_grid(atom); + } + density_calculator.grid.symmetrize_sum(); + auto x = density_calculator.grid; + return std::move(x); +} + +gemmi::Grid<> Sails::XtalDensity::calculate_density_for_structure(gemmi::Structure &structure) const { + gemmi::DensityCalculator, float> density_calculator; + + density_calculator.grid.copy_metadata_from(*get_best_grid()); + density_calculator.grid.spacing[0] = get_best_grid()->spacing[0]; + density_calculator.grid.spacing[1] = get_best_grid()->spacing[1]; + density_calculator.grid.spacing[2] = get_best_grid()->spacing[2]; + + density_calculator.d_min = get_resolution(); + density_calculator.initialize_grid(); + density_calculator.add_model_density_to_grid(structure.models[0]); + density_calculator.grid.symmetrize_sum(); + auto x = density_calculator.grid; + return std::move(x); +} + void Sails::XtalDensity::recalculate_map(gemmi::Structure &structure) { std::vector atoms; form_atom_list(structure, atoms); @@ -136,12 +196,19 @@ void Sails::XtalDensity::recalculate_map(gemmi::Structure &structure) { recalculated_data.emplace_back(hkl.h()); recalculated_data.emplace_back(hkl.k()); recalculated_data.emplace_back(hkl.l()); - recalculated_data.emplace_back(clipper::Util::rad2d(fobs_reflection.f())); - recalculated_data.emplace_back(clipper::Util::rad2d(fobs_reflection.sigf())); - recalculated_data.emplace_back(clipper::Util::rad2d(fbest_reflection.f())); + recalculated_data.emplace_back(fobs_reflection.f()); + recalculated_data.emplace_back(fobs_reflection.sigf()); + recalculated_data.emplace_back(fbest_reflection.f()); 
recalculated_data.emplace_back(clipper::Util::rad2d(fbest_reflection.phi())); - recalculated_data.emplace_back(clipper::Util::rad2d(fdiff_reflection.f())); + recalculated_data.emplace_back(fdiff_reflection.f()); recalculated_data.emplace_back(clipper::Util::rad2d(fdiff_reflection.phi())); + + // recalculated_data.emplace_back(clipper::Util::rad2d(fobs_reflection.f())); + // recalculated_data.emplace_back(clipper::Util::rad2d(fobs_reflection.sigf())); + // recalculated_data.emplace_back(clipper::Util::rad2d(fbest_reflection.f())); + // recalculated_data.emplace_back(clipper::Util::rad2d(fbest_reflection.phi())); + // recalculated_data.emplace_back(clipper::Util::rad2d(fdiff_reflection.f())); + // recalculated_data.emplace_back(clipper::Util::rad2d(fdiff_reflection.phi())); } gemmi::Mtz new_mtz; @@ -160,7 +227,7 @@ void Sails::XtalDensity::recalculate_map(gemmi::Structure &structure) { m_mtz = std::move(new_mtz); m_grid = load_grid(m_mtz, "FWT", "PHWT", false); - m_difference_grid = load_grid(m_mtz, "DELFWT", "PHDELWT", true); + m_difference_grid = load_grid(m_mtz, "DELFWT", "PHDELWT", false); } void Sails::XtalDensity::calculate_po_pc_map(gemmi::Structure &structure) { diff --git a/package/src/cpp/sails-cif.cpp b/package/src/cpp/sails-cif.cpp index abb4671..c3300b0 100644 --- a/package/src/cpp/sails-cif.cpp +++ b/package/src/cpp/sails-cif.cpp @@ -29,3 +29,38 @@ std::vector Sails::generate_link_records(gemmi::Structure *s } return links; } + +void Sails::add_links_to_structure(gemmi::Structure *structure, std::vector &link_records) { + + std::vector connections; + + for (auto& link_record : link_records) { + gemmi::Connection connection; + connection.type = gemmi::Connection::Covale; + gemmi::AtomAddress a1; + a1.chain_name = link_record.chain1.name; + gemmi::ResidueId resid1; + resid1.name = link_record.residue1.name; + resid1.seqid = link_record.residue1.seqid; + a1.res_id = resid1; + a1.atom_name = link_record.atom1.name; + + gemmi::AtomAddress a2; + a2.chain_name 
= link_record.chain2.name; + gemmi::ResidueId resid2; + resid2.name = link_record.residue2.name; + resid2.seqid = link_record.residue2.seqid; + a2.res_id = resid2; + a2.atom_name = link_record.atom2.name; + + double distance = (link_record.atom1.pos - link_record.atom2.pos).length(); + + connection.partner1 = a1; + connection.partner2 = a2; + connection.reported_distance = distance; + connections.push_back(connection); + + } + + structure->connections = connections; +} diff --git a/package/src/cpp/sails-glycan.cpp b/package/src/cpp/sails-glycan.cpp index f49f329..d246de8 100644 --- a/package/src/cpp/sails-glycan.cpp +++ b/package/src/cpp/sails-glycan.cpp @@ -89,7 +89,7 @@ void Sails::Glycan::bfs(Sails::Sugar *root) { } } -void Sails::Glycan::dfs(Sugar *current_sugar, std::vector &terminal_sugars, int depth = 0) { +void Sails::Glycan::dfs_terminal(Sugar *current_sugar, std::vector &terminal_sugars, int depth = 0) { std::set &sugar_set = adjacency_list[current_sugar]; if (sugar_set.empty()) { current_sugar->depth = depth; @@ -98,7 +98,7 @@ void Sails::Glycan::dfs(Sugar *current_sugar, std::vector &terminal_sug for (Sugar *sugar: sugar_set) { sugar->depth = depth + 1; - dfs(sugar, terminal_sugars, depth + 1); + dfs_terminal(sugar, terminal_sugars, depth + 1); } } @@ -113,6 +113,17 @@ void Sails::Glycan::dfs_sites(Sugar *current_sugar, std::vector &site } } +void Sails::Glycan::dfs_sugars(Sugar *current_sugar, std::vector &sugars, int depth) { + const std::set &sugar_set = adjacency_list[current_sugar]; + current_sugar->depth = depth; + sugars.push_back(current_sugar); + + for (Sugar *sugar: sugar_set) { + sugar->depth = depth + 1; + dfs_sugars(sugar, sugars, depth + 1); + } +} + std::set Sails::Glycan::operator-(const Glycan& glycan) { std::set this_keys; std::transform(this->sugars.begin(), this->sugars.end(), std::inserter(this_keys, this_keys.end()), @@ -130,9 +141,24 @@ std::set Sails::Glycan::operator-(const Glycan& glycan) { std::vector 
Sails::Glycan::get_terminal_sugars(Glycosite &root_seq_id) { if (sugars.find(root_seq_id) == sugars.end()) { - throw std::runtime_error("Root SeqId is not valid"); + // throw std::runtime_error("Root SeqId is not valid"); + return {}; } std::vector terminal_sugars; - dfs(sugars[root_seq_id].get(), terminal_sugars); + dfs_terminal(sugars[root_seq_id].get(), terminal_sugars); + + // FUC has no links, but would be the terminal sugar in order, so add the sugar before FUC in that case + std::vector additional_sugars; + for (auto& sugar: terminal_sugars) { + gemmi::Residue* residue_ptr = Utils::get_residue_ptr_from_glycosite(sugar->site, m_structure); + if (residue_ptr->name == "FUC") { + auto previous_sugar = find_previous_sugar(sugar); + if (previous_sugar.has_value()) { + additional_sugars.emplace_back(previous_sugar.value()); + } + } + } + + terminal_sugars.insert(terminal_sugars.end(), additional_sugars.begin(), additional_sugars.end()); return terminal_sugars; } diff --git a/package/src/cpp/sails-json.cpp b/package/src/cpp/sails-json.cpp index f557cc2..bfd9a62 100644 --- a/package/src/cpp/sails-json.cpp +++ b/package/src/cpp/sails-json.cpp @@ -32,6 +32,7 @@ Sails::ResidueDatabase Sails::JSONLoader::load_residue_database() { const char *anomer_key = "anomer"; const char *wurcs_code_key = "wurcsCode"; const char *special_key = "special"; + const char *is_sugar_key = "isSugar"; ResidueDatabase database; @@ -54,7 +55,9 @@ Sails::ResidueDatabase Sails::JSONLoader::load_residue_database() { std::string wurcs_code = std::string(value[wurcs_code_key].get_string().value()); bool special = value[special_key].get_bool(); - ResidueData data = {acceptors_sets, donor_sets, snfg_shape, snfg_colour, preferred_depths, anomer, wurcs_code, special}; + bool is_sugar = value[is_sugar_key].get_bool(); + + ResidueData data = {acceptors_sets, donor_sets, snfg_shape, snfg_colour, preferred_depths, anomer, wurcs_code, special, is_sugar}; database.insert({name, data}); } @@ -150,6 +153,7 @@ 
void Sails::JSONWriter::write_json_file(TelemetryLog &log, std::ostream &stream) stream << "{\n"; stream << "\t\"date\": \"" << strtok(ctime(&t_c), "\n") << "\",\n"; stream << "\t\"cycles\":[\n\t\t"; + int cycle_index = 0; for (const auto &[cycle, entries]: log) { stream << "{\n"; stream << "\t\t\t\"cycle\": " << cycle << ",\n"; @@ -157,14 +161,14 @@ void Sails::JSONWriter::write_json_file(TelemetryLog &log, std::ostream &stream) for (int i = 0; i < entries.size(); ++i) { stream << "\t\t\t\t\"" << entries[i].residue_id << "\": {\"rscc\": " << entries[i].rscc_score << ", \"rsr\": " << entries[i].rsr_score << - ", \"dds\": " << entries[i].dds_score << "}"; + ", \"qscore\": " << entries[i].q_score << "}"; if (i < entries.size() - 1) { stream << ","; } stream << "\n"; } stream << "\t\t\t}\n\t\t}"; - if (cycle < log.size()) stream << ","; + if (++cycle_index < log.size()) stream << ","; } stream << "]\n}"; } diff --git a/package/src/cpp/sails-linkage.cpp b/package/src/cpp/sails-linkage.cpp index f28c97c..1c23ecd 100644 --- a/package/src/cpp/sails-linkage.cpp +++ b/package/src/cpp/sails-linkage.cpp @@ -29,6 +29,61 @@ void Sails::Model::print_successful_log(Sails::Density &density, std::optional names = { + { + "AMAN", "MAN" + } + }; + + for (int m = 0; m < structure->models.size(); m++) { + for (int c = 0; c < structure->models[m].chains.size(); c++) { + for (int r = 0; r < structure->models[m].chains[c].residues.size(); r++) { + gemmi::Residue* residue = &structure->models[m].chains[c].residues[r]; + if (names.count(residue->name) == 0) continue; + + std::string new_name = names.at(residue->name); + residue->name = new_name; + } + } + } +} + +std::set Sails::Model::get_all_glycosites() const { + std::set sites = {}; + for (auto & model : structure->models) { + for (int c = 0; c < model.chains.size(); c++) { + for (int r = 0; r < model.chains[c].residues.size(); r++) { + const gemmi::Residue* residue_ptr = &model.chains[c].residues[r]; + if 
(residue_database.count(residue_ptr->name) > 0) { + ResidueData residue_data = residue_database.at(residue_ptr->name); + if (!residue_data.is_sugar) continue; + Glycosite site = {0, c, r, 0}; + sites.insert(site); + } + } + } + } + return sites; +} + +void Sails::Model::remove_free_sites(std::set &all_sites) const { + std::set all_sites_in_model = get_all_glycosites(); + std::vector free_sites; + std::set_difference(all_sites_in_model.begin(), all_sites_in_model.end(), + all_sites.begin(), all_sites.end(), + std::back_inserter(free_sites)); + + std::sort(free_sites.begin(), free_sites.end(), [](const Sails::Glycosite& a, const Sails::Glycosite& b) { + return !(a < b); + }); + + for (const auto& site: free_sites) { + const auto residues = &structure->models[site.model_idx].chains[site.chain_idx].residues; + residues->erase(residues->begin() + site.residue_idx); + } +} + // UTILITY FUNCTIONS std::optional Sails::Model::get_monomer(const std::string &monomer, bool remove_h) { @@ -40,7 +95,6 @@ std::optional Sails::Model::get_monomer(const std::string &monom std::string path = monomer_library_path + "/" + char(std::tolower(monomer.front())) + "/" + monomer + ".cif"; if (!Utils::file_exists(path)) { - std::cerr << "File " << path << " does not exist" << std::endl; path = special_monomer_path + "/" + monomer + ".cif"; if (!Utils::file_exists(path)) { std::cout << path << " monomer does not exist" << std::endl; @@ -84,7 +138,7 @@ void Sails::Model::save(const std::string &path, std::vector &links) std::ofstream os(path); gemmi::cif::Document document = make_mmcif_document(*structure); gemmi::cif::Block *block = &document.sole_block(); - auto struct_conn = block->find_or_add("_struct_conn", LinkRecord::tags()); + auto struct_conn = block->find_or_add("", LinkRecord::tags()); for (LinkRecord &link: links) { struct_conn.append_row(link.labels()); @@ -186,17 +240,40 @@ void Sails::Model::remove_leaving_atom(Sails::LinkageData &data, gemmi::Residue void 
Sails::Model::add_sugar_to_structure(const Sugar *terminal_sugar, SuperpositionResult &favoured_addition, ChainType &chain_type) { int chain_idx = terminal_sugar->site.chain_idx; - if (chain_type == protein) { - const size_t last_chain_idx = structure->models[terminal_sugar->site.model_idx].chains.size(); - chain_idx = static_cast(last_chain_idx); - gemmi::Chain chain = gemmi::Chain(""); - chain.name = Utils::get_next_string( - structure->models[terminal_sugar->site.model_idx].chains[last_chain_idx - 1].name); - structure->models[terminal_sugar->site.model_idx].chains.emplace_back(chain); + gemmi::Model* model = &structure->models[terminal_sugar->site.model_idx]; + const std::vector* chains = &model->chains; + + if (chains->empty()) { + throw std::runtime_error("No existing chains found in the model. Is it empty?"); + } + + const auto max_it = std::max_element(chains->begin(), chains->end(), + [](const gemmi::Chain& a, const gemmi::Chain& b) { + if (a.name.length() != b.name.length()) { + return a.name.length() < b.name.length(); + } + return a.name < b.name; + }); + + auto new_chain = gemmi::Chain(""); + new_chain.name = Utils::get_next_string(max_it->name); + + model->chains.emplace_back(std::move(new_chain)); + chain_idx = static_cast(model->chains.size() - 1); + + // const size_t last_chain_idx = structure->models[terminal_sugar->site.model_idx].chains.size(); + // chain_idx = static_cast(last_chain_idx); + // gemmi::Chain chain = gemmi::Chain(""); + // chain.name = Utils::get_next_string( + // structure->models[terminal_sugar->site.model_idx].chains[last_chain_idx - 1].name); + // structure->models[terminal_sugar->site.model_idx].chains.emplace_back(chain); } - auto all_residues = &structure->models[terminal_sugar->site.model_idx].chains[chain_idx].residues; + double average_donor_bfactor = Utils::calculate_average_bfactor(terminal_sugar->site, structure); + Utils::set_all_bfactors(&favoured_addition.new_residue, average_donor_bfactor); + + const auto 
all_residues = &structure->models[terminal_sugar->site.model_idx].chains[chain_idx].residues; favoured_addition.new_residue.seqid = gemmi::SeqId(static_cast(all_residues->size()) + 1, '?'); all_residues->insert(all_residues->end(), std::move(favoured_addition.new_residue)); } @@ -252,11 +329,31 @@ Sails::Model::ChainType Sails::Model::find_chain_type(std::vector sugar return result ? non_protein : protein; } -double Sails::Model::calculate_clash_score(const SuperpositionResult &result) const { - constexpr double radius = 1; - gemmi::NeighborSearch ns = gemmi::NeighborSearch(structure->models[0], structure->cell, radius).populate(); +double Sails::Model::calculate_clash_score(const SuperpositionResult &result, gemmi::Atom *donor_atom) const { + return calculate_clash_score(result.new_residue, donor_atom); +} + +double Sails::Model::calculate_clash_score(const gemmi::Residue &residue, gemmi::Atom *donor_atom) const { + constexpr double radius = 1.5; + gemmi::NeighborSearch ns = gemmi::NeighborSearch(structure->models[0], structure->cell, radius); + + for (auto & model : structure->models) { + for (int c = 0; c < model.chains.size(); c++) { + for (int r = 0; r < model.chains[c].residues.size(); r++) { + const gemmi::Residue* residue_ptr = &model.chains[c].residues[r]; + gemmi::ResidueInfo residue_info = gemmi::find_tabulated_residue(residue_ptr->name); + if (residue_info.is_amino_acid() || residue_database.count(residue_ptr->name) > 0 ) { + for (int a = 0; a < model.chains[c].residues[r].atoms.size(); a++) { + gemmi::Atom* current_atom_ptr = &model.chains[c].residues[r].atoms[a]; + if (donor_atom != current_atom_ptr) ns.add_atom(*current_atom_ptr, c, r, a); + } + } + } + } + } + double clash_score = 0; - for (auto &atom: result.new_residue.atoms) { + for (auto &atom: residue.atoms) { auto nearest_atoms = ns.find_atoms(atom.pos, '\0', 0, radius); clash_score += static_cast(nearest_atoms.size()); } @@ -312,6 +409,7 @@ std::optional Sails::Model::add_residue( 
SuperpositionResult best_result; float best_rscc = INT_MIN; + int i = 0; for (auto &cluster: data.clusters) { std::vector torsions = cluster.torsions.get_means_in_order(); std::vector torsion_stddev = cluster.torsions.get_stddev_in_order(); @@ -337,7 +435,8 @@ std::optional Sails::Model::add_residue( } // calculate clash score - double clash_score = calculate_clash_score(result); + double clash_score = calculate_clash_score(result, atoms[2]); + // std::cout << std::endl << clash_score << std::endl; if (clash_score > 1) { continue; } @@ -413,7 +512,7 @@ std::optional Sails::Model::add_residue(gemmi::Resid SuperpositionResult result = {new_monomer, superpose_result, reference_library_monomer}; // calculate clash score - double clash_score = calculate_clash_score(result); + double clash_score = calculate_clash_score(result, &atoms[2]); if (clash_score < best_clash) { best_clash = clash_score; best_result = std::move(result); diff --git a/package/src/cpp/sails-predictions.cpp b/package/src/cpp/sails-predictions.cpp new file mode 100644 index 0000000..88f5cde --- /dev/null +++ b/package/src/cpp/sails-predictions.cpp @@ -0,0 +1,154 @@ +// +// Created by Jordan Dialpuri on 07/10/2025. 
+// + +#include "../include/sails-predictions.h" + + +std::optional Sails::Predictions::create_neighbour_search( + gemmi::Grid<> *grid, float threshold, const gemmi::UnitCell &unit_cell) { + + gemmi::Model model = gemmi::Model(0); + gemmi::Chain chain = gemmi::Chain("A"); + + int seqid = 0; + for (int u = 0; u < grid->nu; u++) { + for (int v = 0; v < grid->nv; v++) { + for (int w = 0; w < grid->nw; w++) { + + gemmi::Grid<>::Point point = grid->get_point(u, v, w); + if (*point.value < threshold) { + continue; + } + gemmi::Position position = grid->point_to_position(point); + gemmi::Atom atom; + atom.name = "X"; + atom.element = gemmi::Element("C"); + atom.pos = position; + gemmi::Residue residue = gemmi::Residue(); + residue.name = "PRD"; + residue.seqid = gemmi::SeqId(++seqid, '0'); + residue.atoms.emplace_back(atom); + chain.residues.emplace_back(residue); + } + } + } + + if (seqid == 0) { + return std::nullopt; + } + + model.chains = {chain}; + + gemmi::NeighborSearch ns = {model, unit_cell, 2}; + ns.populate(); + return ns; +} + +Sails::Glycosites Sails::Predictions::find_potential_sites(gemmi::Structure &structure, bool use_glycan) { + if (use_glycan && m_glycan_map != nullptr) { + return find_potential_sites_using_glycan(structure); + } + if (!use_glycan && m_protein_map != nullptr) { + return find_potential_sites_using_protein(structure); + } + throw std::invalid_argument("Glycan map is null"); +} + +Sails::Glycosites Sails::Predictions::find_potential_sites_using_glycan(gemmi::Structure &structure) { + + Glycosites potential_sites = {}; + + std::optional ns_optional = create_neighbour_search(m_glycan_map, 0.1, structure.cell); + if (!ns_optional.has_value()) { + return potential_sites; + } + gemmi::NeighborSearch ns = ns_optional.value(); + + for (int m = 0; m < structure.models.size(); m++) { + for (int c = 0; c < structure.models[m].chains.size(); c++) { + for (int r = 0; r < structure.models[m].chains[c].residues.size(); r++) { + + Glycosite site = {m, c, 
r}; + gemmi::Residue residue = structure.models[m].chains[c].residues[r]; + std::string residue_name = residue.name; + + if (protein_donors.find(residue_name) == protein_donors.end()) { + continue; + } + + std::vector donor_sets = m_residue_database[residue_name].donors; + + for (const auto& donor_set : donor_sets) { + std::string last_donor_atom_name = donor_set.atom3; + gemmi::Atom* last_donor_atom = residue.find_atom(last_donor_atom_name, '*'); + if (last_donor_atom == nullptr) { + continue; + } + auto nearby_points = ns.find_atoms(last_donor_atom->pos, '*', 0.1, 2); + + if (nearby_points.empty()) { + continue; + } + + potential_sites.emplace_back(site); + break; + } + } + } + } + + return potential_sites; +} + +Sails::Glycosites Sails::Predictions::find_potential_sites_using_protein(gemmi::Structure &structure) { + Glycosites potential_sites = {}; + + std::optional ns_optional = create_neighbour_search(m_protein_map, 0.1, structure.cell); + if (!ns_optional.has_value()) { + return potential_sites; + } + + std::optional glycan_ns_optional = create_neighbour_search(m_glycan_map, 0.1, structure.cell); + if (!glycan_ns_optional.has_value()) { + return potential_sites; + } + + gemmi::NeighborSearch ns = ns_optional.value(); + gemmi::NeighborSearch glycan_ns = glycan_ns_optional.value(); + + for (int m = 0; m < structure.models.size(); m++) { + for (int c = 0; c < structure.models[m].chains.size(); c++) { + for (int r = 0; r < structure.models[m].chains[c].residues.size(); r++) { + + Glycosite site = {m, c, r}; + gemmi::Residue residue = structure.models[m].chains[c].residues[r]; + std::string residue_name = residue.name; + + if (protein_donors.find(residue_name) == protein_donors.end()) { + continue; + } + + std::vector donor_sets = m_residue_database[residue_name].donors; + + for (const auto& donor_set : donor_sets) { + std::string last_donor_atom_name = donor_set.atom3; + gemmi::Atom* last_donor_atom = residue.find_atom(last_donor_atom_name, '*'); + if 
(last_donor_atom == nullptr) { + continue; + } + auto nearby_points = ns.find_atoms(last_donor_atom->pos, '*', 0.1, 1); + auto nearby_glycan_points = glycan_ns.find_atoms(last_donor_atom->pos, '*', 0, 2); + + if (nearby_points.empty() || nearby_glycan_points.empty()) { + continue; + } + + potential_sites.emplace_back(site); + break; + } + } + } + } + return potential_sites; +} diff --git a/package/src/cpp/sails-refine.cpp b/package/src/cpp/sails-refine.cpp index f680a4f..c7e69e6 100644 --- a/package/src/cpp/sails-refine.cpp +++ b/package/src/cpp/sails-refine.cpp @@ -7,44 +7,54 @@ double Sails::TorsionAngleRefiner::calculate_penalty(double angle, double angle_mean, double angle_stddev, double penalty_factor) { - int std_deviations_allowed = 1; - double range = std_deviations_allowed * angle_stddev; - double lower_bound = angle_mean - range; - double upper_bound = angle_mean + range; - - double deviation = 0; - if (angle < lower_bound) { - deviation = lower_bound - angle; - } else { - deviation = angle - upper_bound; - } - - double penalty = penalty_factor * pow(deviation, 2); - return penalty; + // int std_deviations_allowed = 2; + // double range = std_deviations_allowed * angle_stddev; + // double lower_bound = angle_mean - range; + // double upper_bound = angle_mean + range; + // + // double deviation = 0; + // if (angle < lower_bound) { + // deviation = lower_bound - angle; + // } else { + // deviation = angle - upper_bound; + // } + // + // double penalty = penalty_factor * pow(deviation, 2); + // return penalty; + // + double angle_r = angle * M_PI / 180.0; + double angle_mean_r = angle_mean * M_PI / 180.0; + double angle_stddev_r = angle_stddev * M_PI / 180.0; + double diff = angle_r - angle_mean_r; + double delta = atan2(sin(diff), cos(diff)) ; + double penalty = pow(delta, 2) / pow(angle_stddev_r,2); + return penalty * penalty_factor; } double Sails::TorsionAngleRefiner::calculate_penalty_factor() const { switch (m_density->get_score_method()) { case 
atomwise: - return 1e-3; + return 1e-2; case rscc: return 1e-5; + case q: + return 1e-5; default: return 0; } } double Sails::TorsionAngleRefiner::score_function(std::vector &all_angles) { - std::vector angles = {all_angles[0], all_angles[1], all_angles[2]}; - std::vector torsions = {all_angles[3], all_angles[4], all_angles[5]}; + std::vector angles = {all_angles[1], all_angles[2], all_angles[3]}; + std::vector torsions = {all_angles[4], all_angles[5], all_angles[6]}; gemmi::Residue residue = gemmi::Residue(m_reference_residue); - gemmi::Transform superpose_result = Model::superpose_atoms(m_all_atoms, m_reference_atoms, m_length, angles, + gemmi::Transform superpose_result = Model::superpose_atoms(m_all_atoms, m_reference_atoms, all_angles[0], angles, torsions); gemmi::transform_pos_and_adp(residue, superpose_result); SuperpositionResult result = {residue, superpose_result, m_reference_residue}; - const double score = m_density->score_result(result); + const double score = -m_density->score_result(result); double penalty = 0; double penalty_factor = calculate_penalty_factor(); @@ -53,15 +63,34 @@ double Sails::TorsionAngleRefiner::score_function(std::vector &all_angle penalty += calculate_penalty(torsions[i], m_torsion_mean[i], m_torsion_range[i], penalty_factor); } - return penalty-score; + double bond_length_delta = std::abs(all_angles[0] - m_length); + if (bond_length_delta > 0.3) { + penalty += bond_length_delta * 1e5; + } + + return score + penalty; } Sails::SuperpositionResult Sails::TorsionAngleRefiner::refine() { std::vector initial_simplex = { + m_length, m_angle_mean[0], m_angle_mean[1], m_angle_mean[2], m_torsion_mean[0], m_torsion_mean[1], m_torsion_mean[2] }; + // gemmi::Residue reference_residue = gemmi::Residue(m_reference_residue); + // gemmi::Transform reference_superpose_result = Model::superpose_atoms(m_all_atoms, m_reference_atoms, m_length, m_angle_mean, + // m_torsion_mean); + // gemmi::transform_pos_and_adp(reference_residue, 
reference_superpose_result); + // SuperpositionResult reference_result = {reference_residue, reference_superpose_result, m_reference_residue}; + // + // const double initial_score = m_density->score_result(reference_result); + // double penalty = 0; + // for (int i = 0; i < 3; i++) { + // penalty += calculate_penalty(m_angle_mean[i], m_angle_mean[i], m_angle_range[i], calculate_penalty_factor()); + // penalty += calculate_penalty(m_torsion_mean[i], m_torsion_mean[i], m_torsion_range[i], calculate_penalty_factor()); + // } + auto lambda = [&](std::vector &x) -> double { return this->score_function(x); }; @@ -70,24 +99,35 @@ Sails::SuperpositionResult Sails::TorsionAngleRefiner::refine() { 100000); std::vector final_angles = { - final_simplex[0], final_simplex[1], final_simplex[2] + final_simplex[1], final_simplex[2], final_simplex[3] }; std::vector final_torsions = { - final_simplex[3], final_simplex[4], final_simplex[5] + final_simplex[4], final_simplex[5], final_simplex[6] }; gemmi::Residue residue = gemmi::Residue(m_reference_residue); gemmi::Transform final_result = - Model::superpose_atoms(m_all_atoms, m_reference_atoms, m_length, final_angles, final_torsions); + Model::superpose_atoms(m_all_atoms, m_reference_atoms, final_simplex[0], final_angles, final_torsions); gemmi::transform_pos_and_adp(residue, final_result); SuperpositionResult result = {residue, final_result, m_reference_residue}; - // std::vector labels = {"alpha", "beta", "gamma", "psi", "phi", "omega"}; + // const double final_score = m_density->score_result(result); + // double final_penalty = 0; + // for (int i = 0; i < 3; i++) { + // final_penalty += calculate_penalty(final_angles[i], m_angle_mean[i], m_angle_range[i], calculate_penalty_factor()); + // final_penalty += calculate_penalty(final_angles[i], m_torsion_mean[i], m_torsion_range[i], calculate_penalty_factor()); + // } + // + // + // std::cout << std::endl << "Initial score: " << initial_score << " - penalty: " << penalty << std::endl; + 
// std::vector labels = {"length", "alpha", "beta", "gamma", "psi", "phi", "omega"}; // std::cout << "\nLabel\tOriginal\tNew" << std::endl; // for (int i = 0; i < final_simplex.size(); i++) { // std::cout << labels[i] << "\t" << initial_simplex[i] << "\t" << final_simplex[i] << std::endl; // } + // std::cout << "Final score: " << final_score << " - penalty: " << final_penalty << std::endl; + return result; } diff --git a/package/src/cpp/sails-score.cpp b/package/src/cpp/sails-score.cpp new file mode 100644 index 0000000..0e9509f --- /dev/null +++ b/package/src/cpp/sails-score.cpp @@ -0,0 +1,220 @@ +// +// Created by Jordan Dialpuri on 22/10/2025. +// + + +#include "../include/sails-score.h" + +#include + +#include "src/include/sails-utils.h" + +std::map Sails::Score::calculate_rsccs(Density *density, gemmi::Structure *structure, ResidueDatabase &residue_database) { + gemmi::Grid<> calculated_density = density->calculate_density_for_structure(*structure); + + constexpr double radius = 2; + auto ns = gemmi::NeighborSearch(structure->models[0], structure->cell, radius); + ns.populate(); + + gemmi::Grid<> best_grid = *density->get_best_grid(); + std::map>> residue_pairs; + + for (auto point: best_grid) { + gemmi::Position position = best_grid.point_to_position(point); + auto mark = ns.find_nearest_atom(position, radius); + if (mark == nullptr) continue; + + auto site = Glycosite(0, mark->chain_idx, mark->residue_idx, 0); + const gemmi::Residue* residue_ptr = &structure->models[site.model_idx].chains[site.chain_idx].residues[site.residue_idx]; + if (residue_database.count(residue_ptr->name) == 0) continue; + const ResidueData& residue = residue_database.at(residue_ptr->name); + if (!residue.is_sugar) continue; + + double obs = *point.value; + double calc = calculated_density.interpolate_value(position); + residue_pairs[site].emplace_back(obs, calc); + + } + std::map rsccs; + + for (const auto& [site, data]: residue_pairs) { + auto [obs_values, calc_values] = 
Sails::Utils::split_pairs(data); + if (obs_values.empty() || calc_values.empty()) continue; + + rsccs[site] = Sails::Density::calculate_rscc(obs_values, calc_values); + } + return rsccs; +} + +std::map Sails::Score::calculate_qscores(Sails::Density *density, gemmi::Structure *structure, + ResidueDatabase &residue_database) { + + constexpr double radius = 3; + auto ns = gemmi::NeighborSearch(structure->models[0], structure->cell, radius); + ns.populate(); + + auto [mean, stddev] = density->calculate_map_statistics(density->get_best_grid()); + + const float A = mean + (10 * stddev); + const float B = mean - stddev; + const float sigma = 0.6; + constexpr int N = 8; + + std::map qscores; + for (auto & model : structure->models) { + for (int c = 0; c < model.chains.size(); c++) { + for (int r = 0; r < model.chains[c].residues.size(); r++) { + const gemmi::Residue* residue_ptr = &model.chains[c].residues[r]; + if (residue_database.count(residue_ptr->name) > 0) { + const ResidueData& residue_data = residue_database.at(residue_ptr->name); + if (!residue_data.is_sugar) continue; + Glycosite site = {0, c, r, 0}; + + std::vector residue_q_scores = {}; + for (int a = 0; a < residue_ptr->atoms.size(); a++) { + Glycosite atom_site = {0, c, r, a}; + + double atom_q = Sails::Score::QScore::calculate_q_score(residue_ptr->atoms[a].pos, atom_site, + density->get_best_grid(), ns, A, B, sigma, N); + residue_q_scores.emplace_back(atom_q); + } + double mean_residue_q_score = std::accumulate(residue_q_scores.begin(), residue_q_scores.end(), + 0.0) / static_cast(residue_ptr->atoms.size()); + qscores[site] = mean_residue_q_score; + + } + } + } + } + + return qscores; +} + +double Sails::Score::calculate_clash_score(Sails::Glycosite &site, gemmi::Structure *structure) { + constexpr double radius = 1; + auto ns = gemmi::NeighborSearch(structure->models[0], structure->cell, radius); + ns.populate(); + + gemmi::Residue residue = Sails::Utils::get_residue_from_glycosite(site, structure); + 
site.atom_idx = 0; + + double clash_score = 0; + for (auto &atom: residue.atoms) { + auto nearest_atoms = ns.find_atoms(atom.pos, '\0', 0, radius); + for (const auto& nearest_atom: nearest_atoms) { + Glycosite atom_site = {0, nearest_atom->chain_idx, nearest_atom->residue_idx, 0}; + if (atom_site == site) continue; + clash_score += 1; + + } + // clash_score += static_cast(nearest_atoms.size()); + } + return clash_score; +} + +std::vector Sails::Score::QScore::fibonacci_sphere(int samples, float radius, const gemmi::Position ¢er) { + std::vector positions; + const double offset = 2.0 / samples; + const double increment = M_PI * (3.0 - sqrt(5.0)); + + for (int i = 0 ; i < samples; i++) { + const double y = ((i * offset) - 1) + (offset / 2); + const double r = sqrt(1 - pow(y,2)); + + const double phi = i * increment; + + const double x = cos(phi) * r; + const double z = sin(phi) * r; + + gemmi::Position position = {x, y, z}; + position *= radius; + position += center; + positions.emplace_back(position); + } + return positions; +} + +std::vector Sails::Score::QScore::get_radial_points(const gemmi::Position &position, float radius, int N, + Glycosite &site, gemmi::NeighborSearch &ns) { + + std::vector positions; + constexpr int max_iter = 200; + + for (int i = 0 ; i < max_iter ; i++) { + std::vector sampled_sphere = fibonacci_sphere(N+i, radius, position); + for (const auto& sampled_position: sampled_sphere) { + const gemmi::NeighborSearch::Mark* nearest_atom = ns.find_nearest_atom(sampled_position); + auto nearest_site = Glycosite(*nearest_atom); + if (nearest_site == site) { + positions.emplace_back(sampled_position); + } + + if (positions.size() >= N) { + break; + } + } + if (positions.size() >= N) { + break; + } + } + return positions; +} + +std::vector Sails::Score::QScore::sample_density(const gemmi::Grid<> *grid, std::vector &positions) { + std::vector values; + for (auto& position: positions) { + double value = grid->tricubic_interpolation(position); + 
values.emplace_back(value); + } + return values; +} + +double Sails::Score::QScore::calculate_q_score(const gemmi::Position & position, Glycosite &site, const gemmi::Grid<> *grid, + gemmi::NeighborSearch &ns, float A, float B, float sigma, int N) { + + const int M = 21; + std::vector sample_space(M); + for (int i = 0; i < M; i++) + sample_space[i] = (2.0f / (M - 1)) * i; + + std::vector u(N, std::vector(M, 0)); + std::vector v(N, std::vector(M, 0)); + + for (int i = 0; i < M; i++) { + const double radius = sample_space[i]; + const double gaussian_sample = A * exp(-0.5 * pow(radius / sigma, 2)) + B; + + auto radial_pts = get_radial_points(position, radius, N, site, ns); + if (radial_pts.size() != static_cast(N)) + continue; + + auto u_samples = sample_density(grid, radial_pts); + + for (int j = 0; j < N; j++) { + u[j][i] = u_samples[j]; + v[j][i] = gaussian_sample; + } + } + + std::vector u_flat, v_flat; + for (int j = 0; j < N; j++) { + const double mean_u = std::accumulate(u[j].begin(), u[j].end(), 0.0) / M; + const double mean_v = std::accumulate(v[j].begin(), v[j].end(), 0.0) / M; + for (int i = 0; i < M; i++) { + u_flat.push_back(u[j][i] - mean_u); + v_flat.push_back(v[j][i] - mean_v); + } + } + + double numerator = 0; + double sum_u2 = 0; + double sum_v2 = 0; + + for (size_t i = 0; i < u_flat.size(); i++) { + numerator += u_flat[i] * v_flat[i]; + sum_u2 += u_flat[i] * u_flat[i]; + sum_v2 += v_flat[i] * v_flat[i]; + } + + return numerator / (std::sqrt(sum_u2) * std::sqrt(sum_v2)); +} diff --git a/package/src/cpp/sails-telemetry.cpp b/package/src/cpp/sails-telemetry.cpp index aac4a29..b2de6b9 100644 --- a/package/src/cpp/sails-telemetry.cpp +++ b/package/src/cpp/sails-telemetry.cpp @@ -39,12 +39,12 @@ Sails::TelemetryLog Sails::Telemetry::calculate_log(gemmi::Structure *structure, if (residue.atoms.empty()) {continue;} const double rscc_score = density->score_residue(residue, rscc); const double rsr_score = density->score_residue(residue, rsr); - const double 
dds_score = density->score_residue(residue, dds); + const double q_score = density->score_residue(residue, q); log[cycle].emplace_back( Utils::format_residue_from_site(site, structure), rscc_score, rsr_score, - dds_score); + q_score); } } return log; @@ -66,3 +66,21 @@ std::optional Sails::Telemetry::format_log(gemmi::Structure *struct } return std::nullopt; } + + +std::optional Sails::Telemetry::format_log(std::vector &log, bool write, const std::string& filepath) { + JSONWriter writer; + TelemetryLog telemetry_log; + telemetry_log[1] = log; + + if (write) { + std::ofstream stream(filepath); + writer.write_json_file(telemetry_log, stream); + stream.close(); + } else { + std::stringstream stream; + writer.write_json_file(telemetry_log, stream); + return stream.str(); + } + return std::nullopt; +} diff --git a/package/src/cpp/sails-topology.cpp b/package/src/cpp/sails-topology.cpp index 07618d7..bbdc8a3 100644 --- a/package/src/cpp/sails-topology.cpp +++ b/package/src/cpp/sails-topology.cpp @@ -25,10 +25,19 @@ void Sails::Topology::find_residue_near_donor(Glycosite &glycosite, Glycan &glyc gemmi::Residue residue = Utils::get_residue_from_glycosite(glycosite, m_structure); // std::cout << "Searching near " << Utils::get_chain_from_glycosite(glycosite, m_structure).name << "-" << Utils::format_residue_key(&residue) << std::endl; - if (m_database.find(residue.name) == m_database.end()) { throw std::runtime_error("Glycosite is not in database"); } + if (m_database.find(residue.name) == m_database.end()) { + std::cout << residue.name << std::endl; + throw std::runtime_error("Glycosite is not in database"); + } auto database_entry = m_database[residue.name]; for (const auto &donor: database_entry.donors) { + // check if at least one atom, if not add the root but no further sugars + gemmi::Atom* atom = residue.find_atom(donor.atom3, '*'); + if (atom == nullptr) { + continue; + } + // get donor atoms with that name, could return > 1 with altconfs gemmi::AtomGroup 
donor_atoms = residue.get(donor.atom3); for (const auto &donor_atom: donor_atoms) { diff --git a/package/src/cpp/sails-utils.cpp b/package/src/cpp/sails-utils.cpp index 2c9c33d..f460694 100644 --- a/package/src/cpp/sails-utils.cpp +++ b/package/src/cpp/sails-utils.cpp @@ -81,7 +81,7 @@ std::string Sails::Utils::linkage_to_id(const Sails::LinkageData &data) { void Sails::Utils::save_residues_to_file(std::vector residues, const std::string &path) { gemmi::Structure structure; - gemmi::Model model = gemmi::Model("A"); + gemmi::Model model = gemmi::Model(0); gemmi::Chain chain = gemmi::Chain("A"); for (auto& residue : residues) { chain.residues.push_back(residue); @@ -142,3 +142,11 @@ std::vector Sails::Utils::split(const std::string &string, char del } return tokens; } + +gemmi::Model Sails::Utils::create_model(gemmi::Residue &residue) { + auto model = gemmi::Model(0); + auto chain = gemmi::Chain("A"); + chain.residues.push_back(residue); + model.chains.push_back(chain); + return model; +} diff --git a/package/src/cpp/sails-wurcs.cpp b/package/src/cpp/sails-wurcs.cpp index 3faf9a2..70ecaa9 100644 --- a/package/src/cpp/sails-wurcs.cpp +++ b/package/src/cpp/sails-wurcs.cpp @@ -238,7 +238,7 @@ std::vector Sails::WURCS::form_residue_name_order(ResidueDatabase & gemmi::Structure Sails::WURCS::generate_pseudo_structure() { gemmi::Structure pseudo_structure; - gemmi::Model pseudo_model = gemmi::Model(""); + gemmi::Model pseudo_model = gemmi::Model(0); gemmi::Chain chain = gemmi::Chain("A"); pseudo_model.chains.emplace_back(chain); diff --git a/package/src/cpp/sails.cpp b/package/src/cpp/sails.cpp index 3c1e1c7..dc982f7 100644 --- a/package/src/cpp/sails.cpp +++ b/package/src/cpp/sails.cpp @@ -26,21 +26,48 @@ #include #include +#include "src/include/sails-predictions.h" +#include "src/include/sails-score.h" -void print_rejection_dds(const Sails::Glycosite& s1, const Sails::Glycosite& s2, gemmi::Structure* structure, float score) { + +void print_rejection_dds(const 
Sails::Glycosite& s1, const Sails::Glycosite& s2, gemmi::Structure* structure) { std::cout << "Removing " << Sails::Utils::format_residue_from_site(s1, structure) << "--" - << Sails::Utils::format_residue_from_site(s2, structure) << " because of high DDS = " << score <site, structure); + if (clash_score > 2) { + print_removal_clash(snd->site, clash_score, structure) ; + to_remove.push_back(snd.get()); + continue; + } + } + gemmi::Residue previous_residue = Sails::Utils::get_residue_from_glycosite( sugar_result.value()->site, structure); - // if (residue.name == "ASN") { continue; } // don't remove ASN - // if (residue.name == "TRP") { continue; } // don't remove TRP + snd->site.atom_idx = 0; // set atom index to 0 so can be used in comparisons on the residue level // remove cases with low rscc - if (const float rscc = density->rscc_score(residue); rscc < rscc_threshold) { - to_remove.emplace_back(snd.get()); // add pointer to - if (debug) print_removal_rscc(residue, rscc); - continue; + if (rsccs.count(snd->site) != 0) { + const double rscc = rsccs.at(snd->site); + print_rscc(snd->site, rscc, structure); + if (rscc < rscc_threshold) { + to_remove.emplace_back(snd.get()); // add pointer to remove + if (debug) print_removal_rscc(snd->site, rscc, structure); + } + } else { + std::cout << Sails::Utils::format_site_key(fst) << " | " << Sails::Utils::format_site_key(snd->site) << std::endl; + throw std::runtime_error("Glycosite was not found in the RSCC calculation" + Sails::Utils::format_residue_from_site(snd->site, structure)); } // remove cases with high difference density score - if (const float diff_score = density->difference_density_score(residue); diff_score > dds_threshold) { - if (debug) print_rejection_dds(sugar_result.value()->site, fst, structure, diff_score); - to_remove.emplace_back(snd.get()); - } + // const int no_atoms_in_negative_density = density->check_difference_density(residue, difference_density_stats); + // // std::cout << 
Sails::Utils::format_residue_from_site(fst, structure) << " " << no_atoms_in_negative_density << std::endl; + // if (no_atoms_in_negative_density > 4) { + // if (debug) print_rejection_dds(sugar_result.value()->site, fst, structure); + // to_remove.emplace_back(snd.get()); + // } + // print_dds(snd->site, diff_score, structure); + // if ( diff_score > dds_threshold) { + + // } } // add linked sugars to removal list + std::set additional_sugars; for (auto &sugar: to_remove) { - std::vector additional_sugars; - for (auto &linked_sugar: glycan->adjacency_list[sugar]) { - // check that the linked sugar is not already in the removal list - if (std::find(to_remove.begin(), to_remove.end(), linked_sugar) != to_remove.end()) continue; + std::vector downstream_sugars = glycan->get_downstream_sugars(sugar); - additional_sugars.emplace_back(linked_sugar); + for (auto& downstream_sugar: downstream_sugars) { + if (std::find(to_remove.begin(), to_remove.end(), downstream_sugar) != to_remove.end()) continue; + additional_sugars.insert(downstream_sugar); } - to_remove.insert(to_remove.end(), additional_sugars.begin(), additional_sugars.end()); } + to_remove.insert(to_remove.end(), additional_sugars.begin(), additional_sugars.end()); // sort removal in decsending order so removed indices don't cause later array overflow std::sort(to_remove.begin(), to_remove.end(), [](const Sails::Sugar *a, const Sails::Sugar *b) { return !(a->site < b->site); }); + for (const auto &sugar: to_remove) { + glycan->remove_sugar(sugar); + } +} + +void remove_erroneous_sugars_em(gemmi::Structure *structure, Sails::Density *density, Sails::Glycan *glycan, float resolution, + bool debug, Sails::ResidueDatabase &residue_database) { + + std::map qscores = Sails::Score::calculate_qscores(density, structure, residue_database); + double qscore_threshold = -0.0016*pow(resolution,2) + 0.0434*pow(resolution,2)-0.3956*resolution + 1.3366; + + std::vector to_remove; + for (const auto &[fst, snd]: *glycan) { + 
gemmi::Residue residue = Sails::Utils::get_residue_from_glycosite(snd->site, structure); + + std::optional sugar_result = glycan->find_previous_sugar(snd.get()); + if (!sugar_result.has_value()) continue; // if there is nothing previous, it must be a protein residue + + if (residue.name == "FUC") { + double clash_score = Sails::Score::calculate_clash_score(snd->site, structure); + if (clash_score > 2) { + print_removal_clash(snd->site, clash_score, structure) ; + to_remove.push_back(snd.get()); + continue; + } + } + + gemmi::Residue previous_residue = Sails::Utils::get_residue_from_glycosite( + sugar_result.value()->site, structure); + + snd->site.atom_idx = 0; // set atom index to 0 so can be used in comparisons on the residue level + + // remove cases with low rscc + if (qscores.count(snd->site) != 0) { + const double qscore = qscores.at(snd->site); + print_qscore(snd->site, qscore, structure); + if (qscore < qscore_threshold) { + to_remove.emplace_back(snd.get()); // add pointer to remove + if (debug) print_removal_qscore(snd->site, qscore, structure); + } + } else { + std::cout << Sails::Utils::format_site_key(fst) << " | " << Sails::Utils::format_site_key(snd->site) << std::endl; + throw std::runtime_error("Glycosite was not found in the RSCC calculation" + Sails::Utils::format_residue_from_site(snd->site, structure)); + } + } + + // add linked sugars to removal list + std::set additional_sugars; for (auto &sugar: to_remove) { + std::vector downstream_sugars = glycan->get_downstream_sugars(sugar); + + for (auto& downstream_sugar: downstream_sugars) { + if (std::find(to_remove.begin(), to_remove.end(), downstream_sugar) != to_remove.end()) continue; + additional_sugars.insert(downstream_sugar); + } + } + to_remove.insert(to_remove.end(), additional_sugars.begin(), additional_sugars.end()); + + // sort removal in decsending order so removed indices don't cause later array overflow + std::sort(to_remove.begin(), to_remove.end(), [](const Sails::Sugar *a, const 
Sails::Sugar *b) { + return !(a->site < b->site); + }); + + for (const auto &sugar: to_remove) { glycan->remove_sugar(sugar); } } + Sails::Glycan get_glycan_topology(gemmi::Structure &structure, Sails::Glycosite &glycosite) { Sails::JSONLoader loader = {"package/data/data.json"}; Sails::ResidueDatabase residue_database = loader.load_residue_database(); @@ -122,6 +233,13 @@ Sails::Output run_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structu density.recalculate_map(structure); density.calculate_po_pc_map(original_structure); + // + // gemmi::Grid<> x = *density.get_work_grid(); + // gemmi::Ccp4<> m; + // m.grid = x ; + // m.update_ccp4_header(); + // m.write_ccp4_map("wrk.map"); + structure.cell = density.get_mtz()->cell; structure.spacegroup_hm = density.get_mtz()->spacegroup_name; @@ -130,18 +248,34 @@ Sails::Output run_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structu Sails::Telemetry telemetry = Sails::Telemetry(""); + Sails::Glycosites original_glycosites = glycosites; + for (int i = 1; i <= cycles; i++) { if (!verbose) std::cout << "\rCycle #" << i; std::cout << std::flush; if (verbose) std::cout << "\rCycle #" << i << std::endl; + if (glycosites.empty()) break; + std::set unmodellable_sites = {}; + for (auto &glycosite: glycosites) { + // auto c = Sails::Utils::get_chain_from_glycosite(glycosite, &structure); + // auto r = Sails::Utils::get_residue_from_glycosite(glycosite, &structure); + // if (c.name != "D" || r.seqid.num.value != 483) continue; + // + // std::cout << "Checking " << Sails::Utils::format_residue_from_site(glycosite, &structure) << std::endl; Sails::Glycan glycan = topology.find_glycan_topology(glycosite); // if (glycan.empty()) { continue; } // find terminal sugars Sails::Glycan new_glycan = model.extend(glycan, glycosite, density, verbose); + // if nothing was added, add site to unmodellable list + if (new_glycan.size() == glycan.size()) { + std::cout << "Nothing new modelled at site:" << 
Sails::Utils::format_residue_from_site(glycosite, &structure) << std::endl; + unmodellable_sites.insert(glycosite); + } + std::set differences = new_glycan - glycan; telemetry << differences; @@ -152,6 +286,13 @@ Sails::Output run_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structu density.recalculate_map(structure); density.calculate_po_pc_map(original_structure); + // const auto x = density.get_mtz(); + // std::string y = "wrk" + std::to_string(i) + ".mtz"; + // x->write_to_file(y); + // std::string z = "wrk" + std::to_string(i) + ".cif"; + // + // Sails::Utils::save_structure_to_file(structure, z); + // remove erroneous sugars for (auto &glycosite: glycosites) { Sails::Glycan glycan = topology.find_glycan_topology(glycosite); @@ -159,11 +300,18 @@ Sails::Output run_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structu // std::cout << "Attempting removal at " << Sails::Utils::format_residue_from_site(glycosite, &structure) << std::endl; Sails::Glycan old_glycan = glycan; - remove_erroneous_sugars(&structure, &density, &glycan, strict, verbose); + remove_erroneous_sugars(&structure, &density, &glycan, strict, verbose, residue_database); topology.set_structure(&structure); // need to update neighbor search after removing n residues Sails::Glycan new_glycan = topology.find_glycan_topology(glycosite); + if (new_glycan.empty()) { + unmodellable_sites.insert(glycosite); + continue; + } + + new_glycan.renumber(); + std::set differences = old_glycan - new_glycan; telemetry >> differences; @@ -172,17 +320,43 @@ Sails::Output run_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structu telemetry.save_snfg(i, glycosite_key, snfg_string); } + if (verbose && !unmodellable_sites.empty()) { + std::cout << "Stopping trials at " << unmodellable_sites.size() << " sites." 
<< std::endl; + for (const auto& site: unmodellable_sites) { + std::cout << "\tSite:" << Sails::Utils::format_residue_from_site(site, &structure) << std::endl; + } + } + + glycosites.erase( + std::remove_if(glycosites.begin(), glycosites.end(),[&](const Sails::Glycosite &site) { + return unmodellable_sites.count(site) > 0; + }),glycosites.end() + ); + telemetry.save_state(i); } std::cout << std::endl; - // add links and write files - std::vector links = generate_link_records(&structure, &glycosites, &topology); + model.standardise_residue_names(); + + // find and remove any free sugars (likely due to something going wrong) + std::set all_sites = {}; + for (auto &glycosite: original_glycosites) { + Sails::Glycan glycan = topology.find_glycan_topology(glycosite); + auto sites = glycan.get_sites(); + all_sites.insert(sites.begin(), sites.end()); + } + + model.remove_free_sites(all_sites); + topology.set_structure(model.get_structure()); + + // add links and write files + std::vector links = generate_link_records(&structure, &original_glycosites, &topology); + Sails::add_links_to_structure(model.get_structure(), links); Sails::MTZ output_mtz = Sails::form_sails_mtz(*density.get_mtz(), "FP", "SIGFP"); std::string log_string = telemetry.format_log(&structure, &density, false).value(); - Sails::Telemetry::SNFGCycleData snfgs = telemetry.get_snfgs(); return { *model.get_structure(), @@ -192,7 +366,7 @@ Sails::Output run_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structu }; } -Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structure, gemmi::Grid<>& grid, int cycles, +Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &structure, gemmi::Grid<>& grid, float resolution, int cycles, std::string &resource_dir, bool strict, bool verbose) { @@ -206,7 +380,7 @@ Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &stru Sails::Topology topology = {&structure, residue_database}; Sails::SNFG snfg = 
Sails::SNFG(&structure, &residue_database); - auto density = Sails::EMDensity(grid); + auto density = Sails::EMDensity(grid, resolution); structure.cell = density.get_mtz()->cell; structure.spacegroup_hm = density.get_mtz()->spacegroup_name; @@ -221,9 +395,9 @@ Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &stru std::cout << std::flush; if (verbose) std::cout << "\rCycle #" << i << std::endl; + std::cout << "Attempting to model at " << glycosites.size() << " sites." << std::endl; for (auto &glycosite: glycosites) { Sails::Glycan glycan = topology.find_glycan_topology(glycosite); - // if (glycan.empty()) { continue; } // find terminal sugars Sails::Glycan new_glycan = model.extend(glycan, glycosite, density, verbose); @@ -235,17 +409,25 @@ Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &stru } // remove erroneous sugars + std::set unmodellable_sites = {}; for (auto &glycosite: glycosites) { Sails::Glycan glycan = topology.find_glycan_topology(glycosite); - if (glycan.empty()) { continue; } // std::cout << "Attempting removal at " << Sails::Utils::format_residue_from_site(glycosite, &structure) << std::endl; Sails::Glycan old_glycan = glycan; - remove_erroneous_sugars(&structure, &density, &glycan, strict, verbose); + remove_erroneous_sugars_em(&structure, &density, &glycan, resolution, verbose, residue_database); topology.set_structure(&structure); // need to update neighbor search after removing n residues + Sails::Glycan new_glycan = topology.find_glycan_topology(glycosite); + if (new_glycan.empty()) { + unmodellable_sites.insert(glycosite); + continue; + } + + new_glycan.renumber(); + std::set differences = old_glycan - new_glycan; telemetry >> differences; @@ -254,14 +436,40 @@ Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &stru telemetry.save_snfg(i, glycosite_key, snfg_string); } + // sort removal in decsending order so removed indices don't cause later array overflow + if 
(verbose && !unmodellable_sites.empty()) { + std::cout << "Stopping trials at " << unmodellable_sites.size() << " sites." << std::endl; + for (const auto& site: unmodellable_sites) { + std::cout << "\tSITE:" << Sails::Utils::format_residue_from_site(site, &structure) << std::endl; + } + } + + glycosites.erase( + std::remove_if(glycosites.begin(), glycosites.end(),[&](const Sails::Glycosite &site) { + return unmodellable_sites.count(site) > 0; + }),glycosites.end() + ); + telemetry.save_state(i); } std::cout << std::endl; + model.standardise_residue_names(); + + // find and remove any free sugars (likely due to something going wrong) + std::set all_sites = {}; + for (auto &glycosite: glycosites) { + Sails::Glycan glycan = topology.find_glycan_topology(glycosite); + auto sites = glycan.get_sites(); + all_sites.insert(sites.begin(), sites.end()); + } + + model.remove_free_sites(all_sites); + topology.set_structure(model.get_structure()); // add links and write files std::vector links = generate_link_records(&structure, &glycosites, &topology); - + Sails::add_links_to_structure(model.get_structure(), links); std::string log_string = telemetry.format_log(&structure, &density, false).value(); Sails::Telemetry::SNFGCycleData snfgs = telemetry.get_snfgs(); @@ -272,12 +480,35 @@ Sails::Output run_em_cycle(Sails::Glycosites &glycosites, gemmi::Structure &stru }; } +Sails::Glycosites identify_predicted_sites(gemmi::Structure &structure, gemmi::Grid<>& glycan_grid, std::string &resource_dir) { + std::string data_file = resource_dir + "/data.json"; + Sails::JSONLoader loader = {data_file}; + Sails::ResidueDatabase residue_database = loader.load_residue_database(); + Sails::LinkageDatabase linkage_database = loader.load_linkage_database(); + auto predictions = Sails::Predictions(&glycan_grid, linkage_database, residue_database); + + Sails::Glycosites potential_sites = predictions.find_potential_sites(structure, true); + return potential_sites; +} + +Sails::Glycosites 
identify_predicted_sites(gemmi::Structure &structure, gemmi::Grid<>& glycan_grid, gemmi::Grid<>& protein_grid, bool use_glycan, std::string &resource_dir) { + std::string data_file = resource_dir + "/data.json"; + Sails::JSONLoader loader = {data_file}; + Sails::ResidueDatabase residue_database = loader.load_residue_database(); + Sails::LinkageDatabase linkage_database = loader.load_linkage_database(); + auto predictions = Sails::Predictions(&glycan_grid, &protein_grid, linkage_database, residue_database); + + Sails::Glycosites potential_sites = predictions.find_potential_sites(structure, use_glycan); + return potential_sites; +} + + // XRAY FUNCTIONS Sails::Output n_glycosylate(gemmi::Structure &structure, Sails::MTZ &sails_mtz, int cycles, std::string &resource_dir, bool verbose) { auto glycosites = Sails::find_n_glycosylation_sites(structure); - return run_cycle(glycosites, structure, sails_mtz, cycles, resource_dir, false, verbose); + return run_cycle(glycosites, structure, sails_mtz, cycles, resource_dir, true, verbose); } Sails::Output c_glycosylate(gemmi::Structure &structure, Sails::MTZ &sails_mtz, int cycles, std::string &resource_dir, @@ -294,28 +525,111 @@ Sails::Output o_mannosylate(gemmi::Structure &structure, Sails::MTZ &sails_mtz, return run_cycle(glycosites, structure, sails_mtz, cycles, resource_dir, true, verbose); } +Sails::Output auto_glycosylate(gemmi::Structure &structure, Sails::MTZ &sails_mtz, gemmi::Grid<>& glycan_grid, gemmi::Grid<>& protein_grid, int cycles, std::string &resource_dir, + bool verbose) { + Sails::Glycosites predicted_glycosites = identify_predicted_sites(structure, glycan_grid, protein_grid, false, resource_dir); + std::cout << "Found " << predicted_glycosites.size() << " potential sites using deep learning models" << std::endl; + Sails::Glycosites n_glycosites = Sails::find_n_glycosylation_sites(structure); + Sails::Glycosites c_glycosites = Sails::find_c_glycosylation_sites(structure); + + std::set glycosites_set = 
{predicted_glycosites.begin(), predicted_glycosites.end()}; + glycosites_set.insert(n_glycosites.begin(), n_glycosites.end()); + glycosites_set.insert(c_glycosites.begin(), c_glycosites.end()); + Sails::Glycosites glycosites = {glycosites_set.begin(), glycosites_set.end()}; + int diff = static_cast(glycosites.size()) - static_cast(predicted_glycosites.size()); + std::cout << "Supplemented with " << diff << " sites from the sequence" << std::endl; + + // prefer to glycosylate N first, then C, then O. + std::sort(glycosites.begin(), glycosites.end(), + [&](const Sails::Glycosite& a, const Sails::Glycosite& b) { + auto rank = [&](const Sails::Glycosite& s) { + gemmi::Residue* residue = Sails::Utils::get_residue_ptr_from_glycosite(s, &structure); + if (residue->name == "ASN") return 0; + if (residue->name == "TRP") return 1; + if (residue->name == "SER" || residue->name == "THR") return 2; + return 3; + }; + return rank(a) < rank(b); + }); + + + return run_cycle(glycosites, structure, sails_mtz, cycles, resource_dir, false, verbose); +} + +Sails::Output glycosylate_site(gemmi::Structure &structure, Sails::MTZ &sails_mtz, std::string& chain, int seqid, int cycles, std::string &resource_dir, + bool verbose) { + std::optional potential_site = Sails::find_site(structure, chain, seqid); + if (!potential_site.has_value()) { + throw std::runtime_error("Site could not be found"); + } + Sails::Glycosites glycosites = {potential_site.value()}; + return run_cycle(glycosites, structure, sails_mtz, cycles, resource_dir, false, verbose); +} + + // EM FUNCTIONS -Sails::Output n_glycosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, int cycles, std::string &resource_dir, +Sails::Output n_glycosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, float resolution, int cycles, std::string &resource_dir, bool verbose) { auto glycosites = Sails::find_n_glycosylation_sites(structure); - return run_em_cycle(glycosites, structure, grid, cycles, resource_dir, false, verbose); + 
return run_em_cycle(glycosites, structure, grid, resolution, cycles, resource_dir, false, verbose); } -Sails::Output c_glycosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, int cycles, std::string &resource_dir, +Sails::Output c_glycosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, float resolution, int cycles, std::string &resource_dir, bool verbose) { auto glycosites = Sails::find_c_glycosylation_sites(structure); - return run_em_cycle(glycosites, structure, grid, cycles, resource_dir, false, verbose); + return run_em_cycle(glycosites, structure, grid, resolution, cycles, resource_dir, false, verbose); } -Sails::Output o_mannosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, int cycles, std::string &resource_dir, +Sails::Output o_mannosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, float resolution, int cycles, std::string &resource_dir, bool verbose) { Sails::SolventAccessibility sa = Sails::SolventAccessibility(&structure); Sails::SolventAccessibility::SolventAccessibilityMap sa_map = sa.calculate_solvent_accessibility(); auto glycosites = Sails::find_o_mannosylation_sites(structure, sa_map); - return run_em_cycle(glycosites, structure, grid, cycles, resource_dir, true, verbose); + return run_em_cycle(glycosites, structure, grid, resolution, cycles, resource_dir, true, verbose); +} + +Sails::Output auto_glycosylate(gemmi::Structure &structure, gemmi::Grid<>& grid, float resolution, gemmi::Grid<>& glycan_grid, gemmi::Grid<>& protein_grid, int cycles, std::string &resource_dir, + bool verbose) { + Sails::Glycosites predicted_glycosites = identify_predicted_sites(structure, glycan_grid, protein_grid, false, resource_dir); + std::cout << "Found " << predicted_glycosites.size() << " potential sites using deep learning models" << std::endl; + Sails::Glycosites n_glycosites = Sails::find_n_glycosylation_sites(structure); + Sails::Glycosites c_glycosites = Sails::find_c_glycosylation_sites(structure); + + std::set glycosites_set = 
{predicted_glycosites.begin(), predicted_glycosites.end()}; + glycosites_set.insert(n_glycosites.begin(), n_glycosites.end()); + glycosites_set.insert(c_glycosites.begin(), c_glycosites.end()); + Sails::Glycosites glycosites = {glycosites_set.begin(), glycosites_set.end()}; + int diff = static_cast(glycosites.size()) - static_cast(predicted_glycosites.size()); + std::cout << "Supplemented with " << diff << " sites from the sequence" << std::endl; + + // prefer to glycosylate N first, then C, then O. + std::sort(glycosites.begin(), glycosites.end(), + [&](const Sails::Glycosite& a, const Sails::Glycosite& b) { + auto rank = [&](const Sails::Glycosite& s) { + gemmi::Residue* residue = Sails::Utils::get_residue_ptr_from_glycosite(s, &structure); + if (residue->name == "ASN") return 0; + if (residue->name == "TRP") return 1; + if (residue->name == "SER" || residue->name == "THR") return 2; + return 3; + }; + return rank(a) < rank(b); + }); + + return run_em_cycle(glycosites, structure, grid, resolution, cycles, resource_dir, false, verbose); } +Sails::Output glycosylate_site(gemmi::Structure &structure, gemmi::Grid<>& grid, float resolution, std::string& chain, int seqid, int cycles, std::string &resource_dir, + bool verbose) { + std::optional potential_site = Sails::find_site(structure, chain, seqid); + if (!potential_site.has_value()) { + throw std::runtime_error("Site could not be found"); + } + Sails::Glycosites glycosites = {potential_site.value()}; + return run_em_cycle(glycosites, structure, grid, resolution, cycles, resource_dir, false, verbose); +} + + //SNFG FUNCTIONS @@ -462,6 +776,212 @@ gemmi::Structure morph(gemmi::Structure& structure, std::string& wurcs, std::str } +Sails::Output validate(gemmi::Structure& structure, Sails::MTZ &sails_mtz, bool remove, float threshold, std::string& resource_dir) { + std::string data_file = resource_dir + "/data.json"; + Sails::JSONLoader loader = {data_file}; + Sails::ResidueDatabase residue_database = 
loader.load_residue_database(); + Sails::LinkageDatabase linkage_database = loader.load_linkage_database(); + + gemmi::Mtz mtz = form_gemmi_mtz(sails_mtz); + check_spacegroup(&mtz, &structure); // check to ensure the MTZ has a spacegroup + + auto density = Sails::XtalDensity(mtz); + density.load_map_coefficients(); + + std::map rsccs = Sails::Score::calculate_rsccs(&density, &structure, residue_database); + + std::vector to_remove = {}; + std::vector log = {}; + + for (auto& [site, rscc]: rsccs) { + std::string residue_key = Sails::Utils::format_residue_from_site(site, &structure); + log.emplace_back(residue_key, rscc); + if (rscc > threshold) { + continue; + } + to_remove.emplace_back(site); + } + + if (remove) { + Sails::Topology topology = {&structure, residue_database}; + + std::set removal_set = {to_remove.begin(), to_remove.end()}; + + for (auto &site: to_remove) { + auto glycan = topology.find_glycan_topology(site); + std::vector downstream_sugars = glycan.get_downstream_sugars(site); + for (auto& downstream_sugar: downstream_sugars) { + if (std::find(removal_set.begin(), removal_set.end(), downstream_sugar->site) != removal_set.end()) continue; + downstream_sugar->site.atom_idx = 0; // remove atom site from site to allow sorting + removal_set.insert(downstream_sugar->site); + } + } + + std::vector removal_list = {removal_set.begin(), removal_set.end()}; + + std::sort(removal_list.begin(), removal_list.end(), [](const Sails::Glycosite& a, const Sails::Glycosite& b) { + return b < a; // descending order, so erasing by index below never shifts a residue still queued; std::sort needs a strict weak ordering (comp(x, x) must be false) + }); + + for (auto &site: removal_list) { + const auto residue_ptr = &structure.models[site.model_idx].chains[site.chain_idx].residues; + residue_ptr->erase(residue_ptr->begin() + site.residue_idx); + } + } + + + std::string log_string = Sails::Telemetry::format_log(log, false, "").value(); + return { + structure, + log_string + }; +} + +Sails::Output validate_site(gemmi::Structure& structure, Sails::MTZ &sails_mtz, std::string& chain, int seqid, bool remove, float threshold, 
std::string& resource_dir) { + std::string data_file = resource_dir + "/data.json"; + Sails::JSONLoader loader = {data_file}; + Sails::ResidueDatabase residue_database = loader.load_residue_database(); + Sails::LinkageDatabase linkage_database = loader.load_linkage_database(); + + gemmi::Mtz mtz = form_gemmi_mtz(sails_mtz); + check_spacegroup(&mtz, &structure); // check to ensure the MTZ has a spacegroup + + auto density = Sails::XtalDensity(mtz); + density.load_map_coefficients(); + + std::vector to_remove = {}; + std::vector log = {}; + + std::optional potential_site = Sails::find_site(structure, chain, seqid); + if (!potential_site.has_value()) { + throw std::runtime_error("Could not find potential site"); + } + std::cout << "Validating glycans at " << Sails::Utils::format_residue_from_site(potential_site.value(), &structure) << std::endl; + + Sails::Topology topology = {&structure, residue_database}; + auto glycan = topology.find_glycan_topology(potential_site.value()); + auto glycan_sites = glycan.get_sites(); + + std::cout << "Found " << glycan_sites.size() << " sites" << std::endl; + + std::map rsccs = Sails::Score::calculate_rsccs(&density, &structure, residue_database); + + for (auto& [site, rscc]: rsccs) { + if (std::find(glycan_sites.begin(), glycan_sites.end(), site) == glycan_sites.end()) continue; + + std::string residue_key = Sails::Utils::format_residue_from_site(site, &structure); + log.emplace_back(residue_key, rscc); + if (rscc > threshold) { + continue; + } + std::cout << "Scheduling " << Sails::Utils::format_residue_from_site(site, &structure) << " for removal because RSCC " << rscc << "<" << threshold << std::endl; + to_remove.emplace_back(site); + } + + if (remove) { + Sails::Topology topology = {&structure, residue_database}; + + std::set removal_set = {to_remove.begin(), to_remove.end()}; + + for (auto &site: to_remove) { + auto glycan = topology.find_glycan_topology(site); + std::vector downstream_sugars = 
glycan.get_downstream_sugars(site); + for (auto& downstream_sugar: downstream_sugars) { + if (std::find(removal_set.begin(), removal_set.end(), downstream_sugar->site) != removal_set.end()) continue; + downstream_sugar->site.atom_idx = 0; // remove atom site from site to allow sorting + removal_set.insert(downstream_sugar->site); + } + } + + std::vector removal_list = {removal_set.begin(), removal_set.end()}; + + std::sort(removal_list.begin(), removal_list.end(), [](const Sails::Glycosite& a, const Sails::Glycosite& b) { + return b < a; // descending order, so erasing by index below never shifts a residue still queued; std::sort needs a strict weak ordering (comp(x, x) must be false) + }); + + for (auto &site: removal_list) { + const auto residue_ptr = &structure.models[site.model_idx].chains[site.chain_idx].residues; + residue_ptr->erase(residue_ptr->begin() + site.residue_idx); + } + } + + + std::string log_string = Sails::Telemetry::format_log(log, false, "").value(); + return { + structure, + log_string + }; +} + +Sails::Output validate(gemmi::Structure& structure, gemmi::Grid<>& grid, float resolution, bool remove, float threshold, bool use_q, std::string& resource_dir) { + std::string data_file = resource_dir + "/data.json"; + Sails::JSONLoader loader = {data_file}; + Sails::ResidueDatabase residue_database = loader.load_residue_database(); + Sails::LinkageDatabase linkage_database = loader.load_linkage_database(); + + auto density = Sails::EMDensity(grid, resolution); + + + std::map rsccs = Sails::Score::calculate_rsccs(&density, &structure, residue_database); + std::map qscores = Sails::Score::calculate_qscores(&density, &structure, residue_database); + std::map scores = use_q ? qscores : rsccs; + + // equation from https://doi.org/10.1107/S2059798325005923 + double q_score_threshold = -0.0016*pow(resolution,3) + 0.0434*pow(resolution,2)-0.3956*resolution + 1.3366; // NOTE(review): leading term assumed cubic (original had two squared terms) - confirm against the cited paper + + double applied_threshold = use_q ? 
q_score_threshold : threshold ; + + if (remove) { + std::cout << "Enforcing score limit of " << applied_threshold << std::endl; + } + std::vector to_remove = {}; + std::vector log = {}; + + for (auto& [site, score]: scores) { + std::string residue_key = Sails::Utils::format_residue_from_site(site, &structure); + log.emplace_back(residue_key, rsccs.at(site), qscores.at(site)); + if (score > applied_threshold) { + continue; + } + to_remove.emplace_back(site); + } + + if (remove) { + Sails::Topology topology = {&structure, residue_database}; + + std::set removal_set = {to_remove.begin(), to_remove.end()}; + + for (auto &site: to_remove) { + auto glycan = topology.find_glycan_topology(site); + std::vector downstream_sugars = glycan.get_downstream_sugars(site); + for (auto& downstream_sugar: downstream_sugars) { + if (removal_set.count(downstream_sugar->site) > 0) continue; + downstream_sugar->site.atom_idx = 0; // remove atom site from site to allow sorting + removal_set.insert(downstream_sugar->site); + } + } + + std::vector removal_list = {removal_set.begin(), removal_set.end()}; + + std::sort(removal_list.begin(), removal_list.end(), [](const Sails::Glycosite& a, const Sails::Glycosite& b) { + return b < a; // descending order, so erasing by index below never shifts a residue still queued; std::sort needs a strict weak ordering (comp(x, x) must be false) + }); + + for (auto &site: removal_list) { + const auto residue_ptr = &structure.models[site.model_idx].chains[site.chain_idx].residues; + residue_ptr->erase(residue_ptr->begin() + site.residue_idx); + } + } + + + std::string log_string = Sails::Telemetry::format_log(log, false, "").value(); + return { + structure, + log_string + }; +} + + // gemmi::Structure wurcs(gemmi::Structure& structure, std::string chain, int seqid, std::string& resource_dir) { // std::string data_file = resource_dir + "/data.json"; // Sails::JSONLoader loader = {data_file}; @@ -522,5 +1042,5 @@ int main() { std::string data_file = "package/src/sails/data/"; auto glycosites = Sails::find_n_glycosylation_sites(structure); - run_em_cycle(glycosites, structure, map.grid, 1, data_file, 
false, true); + // run_em_cycle(glycosites, structure, map.grid, 1, data_file, false, true); } diff --git a/package/src/include/density/sails-density.h b/package/src/include/density/sails-density.h index 3568fda..6ad907e 100644 --- a/package/src/include/density/sails-density.h +++ b/package/src/include/density/sails-density.h @@ -25,7 +25,7 @@ namespace Sails { typedef clipper::HKL_info::HKL_reference_index HRI; enum DensityScoreMethod { - atomwise, rscc, rsr, dds + atomwise, rscc, rsr, q }; class Density { @@ -46,6 +46,48 @@ namespace Sails { [[nodiscard]] virtual const DensityScoreMethod get_score_method() const = 0; + [[nodiscard]] virtual std::pair get_map_stats() = 0; + + /** + * @brief Calculates the density for a given box based on a gemmi::Residue object. + * + * This method takes a gemmi::Residue object and calculates the density for the specified box + * using the gemmi::DensityCalculator class. The density calculation is performed using the + * density score method specified in the constructor of the gemmi::DensityCalculator. + * + * @param residue The gemmi::Residue object for which the density is calculated. + * @param box + * + * @return The calculated density grid for the specified box. + */ + virtual gemmi::Grid<> calculate_density_for_box(gemmi::Residue &residue, gemmi::Box &box) const = 0; + + /** + * @brief Calculates the density for a gemmi::Residue object. + * + * This method takes a gemmi::Residue object and calculates the density + * using the gemmi::DensityCalculator class. The density calculation is performed using the + * density score method specified in the constructor of the gemmi::DensityCalculator. + * + * @param residue The gemmi::Residue object for which the density is calculated. + * + * @return The calculated density grid for the specified box. + */ + virtual gemmi::Grid<> calculate_density_for_grid(gemmi::Residue &residue) const = 0; + + /** + * @brief Calculates the density for a gemmi::Residue object. 
+ * + * This method takes a gemmi::Residue object and calculates the density + * using the gemmi::DensityCalculator class. The density calculation is performed using the + * density score method specified in the constructor of the gemmi::DensityCalculator. + * + * @param residue The gemmi::Residue object for which the density is calculated. + * + * @return The calculated density grid for the specified box. + */ + virtual gemmi::Grid<> calculate_density_for_structure(gemmi::Structure &structure) const = 0; + /** * @brief Scores a residue based on the specified density score method. * @@ -83,32 +125,7 @@ namespace Sails { */ [[nodiscard]] float atomwise_score(const gemmi::Residue &residue) const; - /** - * @brief Calculates the density for a given box based on a gemmi::Residue object. - * - * This method takes a gemmi::Residue object and calculates the density for the specified box - * using the gemmi::DensityCalculator class. The density calculation is performed using the - * density score method specified in the constructor of the gemmi::DensityCalculator. - * - * @param residue The gemmi::Residue object for which the density is calculated. - * @param box - * - * @return The calculated density grid for the specified box. - */ - gemmi::Grid<> calculate_density_for_box(gemmi::Residue &residue, gemmi::Box &box) const; - /** - * @brief Calculates the density for a gemmi::Residue object. - * - * This method takes a gemmi::Residue object and calculates the density - * using the gemmi::DensityCalculator class. The density calculation is performed using the - * density score method specified in the constructor of the gemmi::DensityCalculator. - * - * @param residue The gemmi::Residue object for which the density is calculated. - * - * @return The calculated density grid for the specified box. - */ - gemmi::Grid<> calculate_density_for_grid(gemmi::Residue &residue) const; /** * @brief Calculates the RSCC (Real Space Correlation Coefficient) score for a given residue. 
@@ -135,7 +152,8 @@ namespace Sails { * * @return The RSCC between the observed and calculated values. */ - static float calculate_rscc(std::vector obs_values, std::vector calc_values) ; + template + static T calculate_rscc(std::vector obs_values, std::vector calc_values) ; /** * @brief Calculates the RSCC score for a given superposition result. @@ -197,10 +215,11 @@ namespace Sails { * This method calculates the difference density score for the given residue using the difference_grid. * * @param residue The gemmi::Residue object for which the difference density score is to be calculated. + * @param map_stats * * @return The difference density score for the residue. */ - float difference_density_score(gemmi::Residue &residue) const; + int check_difference_density(gemmi::Residue &residue, std::pair map_stats) const; /** * @brief Scores an atom @@ -224,6 +243,11 @@ namespace Sails { */ [[nodiscard]] float score_position(const gemmi::Position& pos) const; + + [[nodiscard]] std::pair calculate_map_statistics(const gemmi::Grid<> *grid) const; + + [[nodiscard]] double q_score(gemmi::Residue &residue); + }; } // namespace Sails diff --git a/package/src/include/density/sails-em-density.h b/package/src/include/density/sails-em-density.h index d51e1ce..998f26e 100644 --- a/package/src/include/density/sails-em-density.h +++ b/package/src/include/density/sails-em-density.h @@ -5,9 +5,9 @@ #include "sails-density.h" namespace Sails { - class EMDensity : public Density { + class EMDensity : public Density{ public: - explicit EMDensity(gemmi::Grid<> &grid); + explicit EMDensity(gemmi::Grid<> &grid, float resolution); [[nodiscard]] const gemmi::Mtz *get_mtz() const override { return &m_mtz; } @@ -17,7 +17,7 @@ namespace Sails { [[nodiscard]] const gemmi::Grid<> *get_difference_grid() const override { return &m_grid; } - [[nodiscard]] const double get_resolution() const override { return 2.0; } + [[nodiscard]] const double get_resolution() const override { return m_resolution; } 
[[nodiscard]] const DensityScoreMethod get_score_method() const override { return score_method; } @@ -25,6 +25,23 @@ namespace Sails { return &calculated_maps; } + [[nodiscard]] std::pair get_map_stats() override { + if (map_mean == INT_MIN || map_stddev == INT_MIN ) { + auto [mean, stddev] = calculate_map_statistics(get_work_grid()); + map_mean = mean; + map_stddev = stddev; + return std::make_pair(map_mean, map_stddev); + } + return std::make_pair(map_mean, map_stddev); + } + + gemmi::Grid<> calculate_density_for_box(gemmi::Residue &residue, gemmi::Box &box) const override; + + gemmi::Grid<> calculate_density_for_grid(gemmi::Residue &residue) const override; + + gemmi::Grid<> calculate_density_for_structure(gemmi::Structure &structure) const override; + + private: /** * Best map @@ -46,6 +63,8 @@ namespace Sails { */ gemmi::Mtz m_mtz; + float m_resolution; + /** * Fc maps for residues in standard positions - used for fast RSCC calculations */ @@ -56,6 +75,9 @@ namespace Sails { * * The DensityScoreMethod class is used to represent the score method for scoring residues to density */ - DensityScoreMethod score_method = rscc; + DensityScoreMethod score_method = atomwise; + + float map_mean = INT_MIN; + float map_stddev = INT_MIN; }; } diff --git a/package/src/include/density/sails-xtal-density.h b/package/src/include/density/sails-xtal-density.h index 2811299..60d7d52 100644 --- a/package/src/include/density/sails-xtal-density.h +++ b/package/src/include/density/sails-xtal-density.h @@ -14,6 +14,8 @@ namespace Sails { explicit XtalDensity(gemmi::Mtz &mtz, const std::string& F, const std::string& SIGF); + void load_map_coefficients(const std::string& fwt = "FWT", const std::string& phwt = "PHWT"); + [[nodiscard]] const gemmi::Mtz *get_mtz() const override { return &m_mtz; } [[nodiscard]] const gemmi::Grid<> *get_work_grid() const override { return &m_po_pc_grid; } @@ -22,7 +24,7 @@ namespace Sails { [[nodiscard]] const gemmi::Grid<> *get_difference_grid() const 
override { return &m_difference_grid; } - [[nodiscard]] const double get_resolution() const override { return 2.0; } + [[nodiscard]] const double get_resolution() const override { return m_mtz.resolution_high(); } [[nodiscard]] const DensityScoreMethod get_score_method() const override { return score_method;} @@ -30,6 +32,22 @@ namespace Sails { return &calculated_maps; } + [[nodiscard]] std::pair get_map_stats() override { + if (map_mean == INT_MIN || map_stddev == INT_MIN ) { + auto [mean, stddev] = calculate_map_statistics(get_best_grid()); + map_mean = mean; + map_stddev = stddev; + return std::make_pair(map_mean, map_stddev); + } + return std::make_pair(map_mean, map_stddev); + } + + gemmi::Grid<> calculate_density_for_box(gemmi::Residue &residue, gemmi::Box &box) const override; + + gemmi::Grid<> calculate_density_for_grid(gemmi::Residue &residue) const override; + + gemmi::Grid<> calculate_density_for_structure(gemmi::Structure &structure) const override; + /** * @brief Recalculates the map based on the given structure. 
* @@ -165,5 +183,9 @@ namespace Sails { * Clipper best map */ clipper::Xmap m_best_map; + + float map_mean = INT_MIN; + float map_stddev = INT_MIN; + }; } diff --git a/package/src/include/sails-cif.h b/package/src/include/sails-cif.h index c849b15..366bc82 100644 --- a/package/src/include/sails-cif.h +++ b/package/src/include/sails-cif.h @@ -88,8 +88,8 @@ namespace Sails { */ std::vector labels() { double distance = (atom1.pos - atom2.pos).length(); - std::string res1_seqid = residue1.seqid.str(); - std::string res2_seqid = residue2.seqid.str(); + std::string res1_seqid = residue1.seqid.num.str(); + std::string res2_seqid = residue2.seqid.num.str(); return { id, @@ -130,7 +130,6 @@ namespace Sails { }; } - private: gemmi::Chain chain1; gemmi::Chain chain2; gemmi::Residue residue1; @@ -153,6 +152,9 @@ namespace Sails { std::vector generate_link_records(gemmi::Structure *structure, Sails::Glycosites *glycosites, Sails::Topology *topology); + + void add_links_to_structure(gemmi::Structure *structure, std::vector& link_records); + } #endif //SAILS_CIF_H diff --git a/package/src/include/sails-gemmi-bindings.h b/package/src/include/sails-gemmi-bindings.h index e7c6c84..5263f2f 100644 --- a/package/src/include/sails-gemmi-bindings.h +++ b/package/src/include/sails-gemmi-bindings.h @@ -152,12 +152,14 @@ namespace Sails { * log string. 
*/ struct Output { - Output(gemmi::Structure& structure, MTZ& mtz, std::string log, std::map>& snfgs): + Output(gemmi::Structure& structure, MTZ& mtz, std::string& log, std::map>& snfgs): structure(structure), mtz(mtz), log(std::move(log)), snfgs(snfgs){}; - Output(gemmi::Structure& structure, std::string log, std::map>& snfgs): + Output(gemmi::Structure& structure, std::string& log, std::map>& snfgs): structure(structure), log(std::move(log)), snfgs(snfgs){}; + Output(gemmi::Structure& structure, std::string& log): structure(structure), log(std::move(log)) {}; + gemmi::Structure structure ; MTZ mtz{}; std::string log; diff --git a/package/src/include/sails-glycan.h b/package/src/include/sails-glycan.h index 76df3bd..76c60e9 100644 --- a/package/src/include/sails-glycan.h +++ b/package/src/include/sails-glycan.h @@ -244,6 +244,26 @@ namespace Sails { return sites; } + /** + * @brief Returns the DFS order of the sugars sites. + * + * @return A vector of Glycosites in DFS order + */ + [[nodiscard]] std::vector get_sugar_site_dfs_order_without_root() { + std::vector sites; + dfs_sites(root_sugar, sites, 0); + sites.erase(sites.begin()); + return sites; + } + + void renumber() { + std::vector sites = get_sugar_site_dfs_order_without_root(); + for (int i = 0; i < sites.size(); i++) { + gemmi::Residue* residue_ptr = Utils::get_residue_ptr_from_glycosite(sites[i], m_structure); + residue_ptr->seqid.num.value = i+1; + } + } + /** * @brief Returns the order of the sugars. 
* @@ -298,6 +318,25 @@ namespace Sails { return count-1; } + + [[nodiscard]] std::vector get_downstream_sugars(Sugar* sugar) { + std::vector downstream_sugars; + dfs_sugars(sugar, downstream_sugars, 0); + // downstream_sugars.erase(downstream_sugars.begin()); + return downstream_sugars; + } + + [[nodiscard]] std::vector get_downstream_sugars(Glycosite& site) { + std::vector downstream_sugars; + if (sugars.count(site) == 0) { + return {}; + } + Sugar* sugar = sugars.at(site).get(); + dfs_sugars(sugar, downstream_sugars, 0); + // downstream_sugars.erase(downstream_sugars.begin()); + return downstream_sugars; + } + /** * @brief Returns internal adjacency list. * @@ -325,6 +364,24 @@ namespace Sails { return &sugars; } + + /** + * @brief Returns the sites in this glycan. + * + * @return A ptr to all sugars in this glycan. + */ + [[nodiscard]] std::vector get_sites() const { + std::vector sites; + sites.reserve(sugars.size()); + for(const auto&[fst, snd]: sugars) { + Glycosite site = fst; + site.atom_idx = 0; // set to 0 for later comparisons + sites.emplace_back(site); + } + return sites; + } + + /** * @brief Adds linkage between two sugars. * @@ -541,7 +598,7 @@ namespace Sails { * @param terminal_sugars - A vector to store the terminal sugar molecules found. * @param depth - The depth of the current search */ - [[maybe_unused]] void dfs(Sugar *current_sugar, std::vector &terminal_sugars, int depth); + [[maybe_unused]] void dfs_terminal(Sugar *current_sugar, std::vector &terminal_sugars, int depth); /** * Performs a depth-first search (DFS) on a graph of sugar molecules, starting from @@ -554,6 +611,16 @@ namespace Sails { [[maybe_unused]] void dfs_sites(Sugar *current_sugar, std::vector &sites, int depth); + /** + * Performs a depth-first search (DFS) on a graph of sugar molecules, starting from + * a given sugar and collecting terminal sugars. + * + * @param current_sugar - The current sugar molecule being visited. 
+ * @param sites - A vector to store the sites + * @param depth - The depth of the current search + */ + [[maybe_unused]] void dfs_sugars(Sugar *current_sugar, std::vector &sugars, int depth); + /** * @brief Get the structure associated with the glycan. * diff --git a/package/src/include/sails-linkage.h b/package/src/include/sails-linkage.h index 91e16f5..636371f 100644 --- a/package/src/include/sails-linkage.h +++ b/package/src/include/sails-linkage.h @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -142,6 +143,13 @@ namespace Sails { static gemmi::Residue replace_residue(gemmi::Residue *target_residue, const std::string &replacement_residue_name); + + + + void standardise_residue_names() const; + + void remove_free_sites(std::set& all_sites) const; + private: typedef std::map > PossibleAdditions; @@ -217,9 +225,21 @@ namespace Sails { * SuperpositionResult. Nearby atoms are found using a NeighborSearch with a given radius. * * @param result The SuperpositionResult from which to calculate the clash score. + * @param donor_atom + * @return The calculated clash score. + */ + [[nodiscard]] double calculate_clash_score(const SuperpositionResult &result, gemmi::Atom *donor_atom) const; + + /** @brief Calculates the clash score for the given SuperpositionResult. + * + * The clash score is calculated by finding the number of nearby atoms for each atom in the + * SuperpositionResult. Nearby atoms are found using a NeighborSearch with a given radius. + * + * @param result The SuperpositionResult from which to calculate the clash score. + * @param donor_atom * @return The calculated clash score. 
*/ - [[nodiscard]] double calculate_clash_score(const SuperpositionResult &result) const; + [[nodiscard]] double calculate_clash_score(const gemmi::Residue &residue, gemmi::Atom *donor_atom) const; /** @@ -252,6 +272,10 @@ namespace Sails { static void move_acceptor_atomic_positions(std::vector &atoms, double length, std::vector &angles, std::vector &torsions); + + [[nodiscard]] std::set get_all_glycosites() const; + + // /** // * @brief Move the positions of acceptor atoms based on given parameters. // * diff --git a/package/src/include/sails-model.h b/package/src/include/sails-model.h index 9260b81..ece9b50 100644 --- a/package/src/include/sails-model.h +++ b/package/src/include/sails-model.h @@ -7,7 +7,8 @@ #include #include - +#include +#include #include namespace Sails { @@ -60,13 +61,15 @@ namespace Sails { ResidueData(const std::vector &acceptors, const std::vector &donors, std::string &snfg_shape, std::string &snfg_colour, std::vector &preferred_depths, std::string &anomer, - std::string &wurcs, bool special + std::string &wurcs, bool special, bool is_sugar ) : acceptors(acceptors), donors(donors), snfg_shape(std::move(snfg_shape)), snfg_colour(std::move(snfg_colour)), preferred_depths(preferred_depths), anomer(anomer), - special(special) { + special(special), + is_sugar(is_sugar) + { if (!wurcs.empty()) {wurcs_code = wurcs;} @@ -89,6 +92,7 @@ namespace Sails { std::vector preferred_depths; std::string anomer; bool special; + bool is_sugar; std::optional wurcs_code = std::nullopt; }; @@ -219,6 +223,25 @@ namespace Sails { typedef std::map > LinkageDatabase; + /** @brief Find protein donors in LinkageDatabase + * + */ + inline std::set find_protein_donors(LinkageDatabase &linkage_database) { + std::set acceptor_names = {}; + std::set donor_names = {}; + for (const auto& [donor_name, linkages]: linkage_database) { + for (auto& linkage: linkages) { + acceptor_names.insert(linkage.acceptor); + } + donor_names.insert(donor_name); + } + std::set difference = {}; 
+ std::set_difference(donor_names.begin(), donor_names.end(), acceptor_names.begin(), + acceptor_names.end(), std::inserter(difference, difference.begin())); + return difference; + } + + /** * @class Glycosite * @brief A class representing a glycosite. diff --git a/package/src/include/sails-predictions.h b/package/src/include/sails-predictions.h new file mode 100644 index 0000000..28d2561 --- /dev/null +++ b/package/src/include/sails-predictions.h @@ -0,0 +1,48 @@ +// +// Created by Jordan Dialpuri on 07/10/2025. +// + +#ifndef SAILS_PREDICTIONS_H +#define SAILS_PREDICTIONS_H + +#include +#include +#include "sails-model.h" +#include "sails-utils.h" + +namespace Sails { + + class Predictions { + public: + explicit Predictions(gemmi::Grid<>* glycan_map, LinkageDatabase& linkage_database, ResidueDatabase& residue_database): m_residue_database(residue_database) { + protein_donors = find_protein_donors(linkage_database); + m_glycan_map = glycan_map; + }; + + explicit Predictions(gemmi::Grid<>* glycan_map, gemmi::Grid<>* protein_map, LinkageDatabase& linkage_database, ResidueDatabase& residue_database): m_residue_database(residue_database) { + protein_donors = find_protein_donors(linkage_database); + m_glycan_map = glycan_map; + m_protein_map = protein_map; + }; + + Glycosites find_potential_sites(gemmi::Structure &structure, bool use_glycan); + + private: + std::optional create_neighbour_search(gemmi::Grid<> *grid, float threshold, + const gemmi::UnitCell &unit_cell); + + Glycosites find_potential_sites_using_glycan(gemmi::Structure &structure); + + Glycosites find_potential_sites_using_protein(gemmi::Structure &structure); + + + gemmi::Grid<>* m_glycan_map = nullptr; + gemmi::Grid<>* m_protein_map = nullptr; + std::set protein_donors; + Sails::ResidueDatabase& m_residue_database; + }; + +} + + +#endif //SAILS_PREDICTIONS_H diff --git a/package/src/include/sails-refine.h b/package/src/include/sails-refine.h index b4e7241..9ff86d7 100644 --- 
a/package/src/include/sails-refine.h +++ b/package/src/include/sails-refine.h @@ -12,7 +12,7 @@ #include #include -#include +#include namespace Sails { diff --git a/package/src/include/sails-score.h b/package/src/include/sails-score.h new file mode 100644 index 0000000..94b3194 --- /dev/null +++ b/package/src/include/sails-score.h @@ -0,0 +1,30 @@ +// +// Created by Jordan Dialpuri on 22/10/2025. +// + +#ifndef SAILS_SCORE_H +#define SAILS_SCORE_H +#include "sails-model.h" +#include "density/sails-density.h" + +namespace Sails::Score { + + std::map calculate_rsccs(Sails::Density* density, gemmi::Structure* structure, ResidueDatabase &residue_database); + + std::map calculate_qscores(Sails::Density* density, gemmi::Structure* structure, ResidueDatabase &residue_database); + + double calculate_clash_score(Sails::Glycosite &site, gemmi::Structure* structure); + + namespace QScore { + std::vector fibonacci_sphere(int samples, float radius, const gemmi::Position ¢er); + + std::vector get_radial_points(const gemmi::Position &position, float radius, int N, Glycosite& site, gemmi::NeighborSearch& ns); + + std::vector sample_density(const gemmi::Grid<> *grid, std::vector& positions); + + double calculate_q_score(const gemmi::Position & position, Glycosite &site, const gemmi::Grid<> *grid, + gemmi::NeighborSearch &ns, float A, float B, float sigma, int N); + } +} + +#endif //SAILS_SCORE_H diff --git a/package/src/include/sails-telemetry.h b/package/src/include/sails-telemetry.h index 99956bc..d828fbf 100644 --- a/package/src/include/sails-telemetry.h +++ b/package/src/include/sails-telemetry.h @@ -18,17 +18,27 @@ namespace Sails { struct TelemetryFormat { TelemetryFormat() = default; - TelemetryFormat(const std::string &residue_id, double rscc_score, double rsr_score, double dds_score) + TelemetryFormat(const std::string &residue_id, double rscc_score, double rsr_score, double q_score) : residue_id(residue_id), rscc_score(rscc_score), rsr_score(rsr_score), - 
dds_score(dds_score) { + q_score(q_score) { + } + + TelemetryFormat(const std::string &residue_id, double rscc_score) + : residue_id(residue_id), + rscc_score(rscc_score), rsr_score(0), q_score(0) { + } + + TelemetryFormat(const std::string &residue_id, double rscc_score, double q_score) + : residue_id(residue_id), + rscc_score(rscc_score), rsr_score(0), q_score(q_score) { } std::string residue_id; double rscc_score; double rsr_score; - double dds_score; + double q_score; }; typedef std::map> TelemetryLog; @@ -153,6 +163,7 @@ namespace Sails { */ void format_log(gemmi::Structure* structure); + static std::optional format_log(std::vector& log, bool write, const std::string& filepath); /** * @brief Calculates the telemetry log for Sails. diff --git a/package/src/include/sails-utils.h b/package/src/include/sails-utils.h index ef323fd..4245b35 100644 --- a/package/src/include/sails-utils.h +++ b/package/src/include/sails-utils.h @@ -257,6 +257,36 @@ namespace Sails::Utils { * @return a vector of strings split by the delimiter */ std::vector split(const std::string &string, char delimiter); + + gemmi::Model create_model(gemmi::Residue& residue); + + + template + std::pair, std::vector> split_pairs(const std::vector> &pairs) { + std::vector firsts; + std::vector seconds; + firsts.reserve(pairs.size()); + seconds.reserve(pairs.size()); + for (const auto& p : pairs) { + firsts.push_back(p.first); + seconds.push_back(p.second); + } + return {std::move(firsts), std::move(seconds)}; + } + + inline double calculate_average_bfactor(const Glycosite &site, gemmi::Structure * structure) { + gemmi::Residue* residue_ptr = get_residue_ptr_from_glycosite(site, structure); + const double sum = std::accumulate(residue_ptr->atoms.begin(), residue_ptr->atoms.end(), 0.0, [](const double current, gemmi::Atom& atom) { + return current + atom.b_iso; + }); + return sum / residue_ptr->atoms.size(); + } + + inline void set_all_bfactors(gemmi::Residue * residue, double b_factor) { + for (auto & 
atom : residue->atoms) { + atom.b_iso = b_factor; + } + } } // namespace Sails::Utils diff --git a/package/src/sails/__init__.py b/package/src/sails/__init__.py index c21863f..2b5b677 100644 --- a/package/src/sails/__init__.py +++ b/package/src/sails/__init__.py @@ -26,6 +26,14 @@ find_wurcs, model_wurcs, morph, + identify_predicted_sites, + auto_glycosylate, + validate, + glycosylate_site, + Connections, + AtomAddress, + ResidueId, + validate_site, ) from .__version__ import __version__ from .glycosylate import glycosylate_xtal, glycosylate_em, Type @@ -79,4 +87,12 @@ "find_wurcs", "model_wurcs", "morph", + "identify_predicted_sites", + "auto_glycosylate", + "validate", + "glycosylate_site", + "Connections", + "AtomAddress", + "ResidueId", + "validate_site", ] diff --git a/package/src/sails/clean.py b/package/src/sails/clean.py new file mode 100644 index 0000000..27a4ad4 --- /dev/null +++ b/package/src/sails/clean.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 Jordan Dialpuri, Jon Agirre, Kathryn Cowtan, Paul Bond and University of York. 
All rights reserved + +import site +import os + + +def clean_models(): + site_packages_dir = site.getsitepackages() + + found_models = [] + + for folder in site_packages_dir: + sails_model_dir = os.path.join(folder, "sails_models") + if os.path.exists(sails_model_dir): + for model in os.scandir(sails_model_dir): + found_models.append(model) + + if not found_models: + print("No models were found in site-packages.") + return + + print("Pick an option to remove: ") + + for index, model in enumerate(found_models): + print(f"{index + 1}) {model.name.removesuffix('.onnx')}") + + if len(found_models) > 1: + print(f"{len(found_models) + 1}) All") + + option_selected = False + while not option_selected: + option = input("Number: ") + + try: + choice = int(option) + if choice <= 0 or choice > len(found_models) + 1: + raise ValueError() + + if choice == len(found_models) + 1: + print("Do you want to remove all the models?") + else: + model_to_remove = found_models[choice - 1] + print(f"Confirm you want to remove {model_to_remove.name}?") + + y_no_selected = False + confirm = False + while not y_no_selected: + y_or_n = input("Y/N ").lower() + if y_or_n not in ["y", "yes", "n", "no"]: + continue + + y_no_selected = True + + if y_or_n == "y" or y_or_n == "yes": + confirm = True + + if confirm: + if choice == len(found_models) + 1: + for model in found_models: + os.remove(model.path) + print("Removed", model.name) + else: + model_to_remove = found_models[choice - 1] + os.remove(model_to_remove.path) + print("Removed", model_to_remove.name) + + option_selected = True + except ValueError: + print("Invalid choice") + + +def run(): + clean_models() diff --git a/package/src/sails/data/data.json b/package/src/sails/data/data.json index e120170..f124461 100644 --- a/package/src/sails/data/data.json +++ b/package/src/sails/data/data.json @@ -10,6 +10,7 @@ "anomer": "α", "wurcsCode": "", "special": false, + "isSugar": true, "donorSets": [], "acceptorSets": [ { @@ -30,6 +31,7 @@ "anomer": 
"β", "wurcsCode": "", "special": false, + "isSugar": true, "donorSets": [ { "atom3": "O6", @@ -57,6 +59,7 @@ "anomer": "α", "wurcsCode": "a1221m-1a_1-5", "special": true, + "isSugar": true, "donorSets": [], "acceptorSets": [ { @@ -81,6 +84,7 @@ "anomer": "α", "wurcsCode": "a1122h-1a_1-5", "special": false, + "isSugar": true, "donorSets": [ { "atom3": "O2", @@ -126,6 +130,7 @@ "anomer": "α", "wurcsCode": "a1122h-1a_1-5*", "special": false, + "isSugar": true, "donorSets": [], "acceptorSets": [ { @@ -146,6 +151,7 @@ "anomer": "β", "wurcsCode": "a1122h-1b_1-5", "special": false, + "isSugar": true, "donorSets": [ { "atom3": "O6", @@ -175,11 +181,13 @@ "snfgColour": "#0090bc", "preferredDepths": [ 1, - 2 + 2, + 7 ], "anomer": "β", "wurcsCode": "a2122h-1b_1-5_2*NCC/3=O", "special": false, + "isSugar": true, "donorSets": [ { "atom1": "C5", @@ -219,6 +227,7 @@ "anomer": "", "wurcsCode": "", "special": false, + "isSugar": false, "donorSets": [ { "atom1": "CB", @@ -239,6 +248,7 @@ "anomer": "", "wurcsCode": "", "special": false, + "isSugar": false, "donorSets": [ { "atom1": "CD2", @@ -259,6 +269,7 @@ "anomer": "", "wurcsCode": "", "special": false, + "isSugar": false, "donorSets": [ { "atom1": "CA", @@ -279,6 +290,7 @@ "anomer": "", "wurcsCode": "", "special": false, + "isSugar": false, "donorSets": [ { "atom1": "CA", @@ -300,39 +312,60 @@ "clusters": [ { "angles": { - "alphaMean": 127.765, - "alphaStdDev": 8.153, - "betaMean": 112.241, - "betaStdDev": 8.153, - "gammaMean": 113.442, - "gammaStdDev": 1.689 + "alphaMean": 127.742, + "alphaStdDev": 5.725, + "betaMean": 109.106, + "betaStdDev": 5.725, + "gammaMean": 114.671, + "gammaStdDev": 1.942 }, "torsions": { - "phiMean": 54.086, - "phiStdDev": 36.298, - "psiMean": -179.512, - "psiStdDev": 53.054, - "omegaMean": 167.981, - "omegaStdDev": 11.815 - } + "phiMean": -97.7, + "phiStdDev": 27.286, + "psiMean": 359.451, + "psiStdDev": 6.833, + "omegaMean": -179.573, + "omegaStdDev": 8.716 + }, + "priority": false }, { "angles": { - 
"alphaMean": 122.28, - "alphaStdDev": 5.253, - "betaMean": 109.788, - "betaStdDev": 5.253, - "gammaMean": 113.963, - "gammaStdDev": 1.54 + "alphaMean": 121.376, + "alphaStdDev": 4.178, + "betaMean": 109.598, + "betaStdDev": 4.178, + "gammaMean": 113.848, + "gammaStdDev": 1.686 }, "torsions": { - "phiMean": -106.706, - "phiStdDev": 30.219, - "psiMean": 179.446, - "psiStdDev": 28.245, - "omegaMean": 176.266, - "omegaStdDev": 7.108 - } + "phiMean": -102.636, + "phiStdDev": 25.777, + "psiMean": 179.124, + "psiStdDev": 6.129, + "omegaMean": 176.271, + "omegaStdDev": 6.38 + }, + "priority": true + }, + { + "angles": { + "alphaMean": 125.118, + "alphaStdDev": 4.63, + "betaMean": 112.176, + "betaStdDev": 4.63, + "gammaMean": 113.306, + "gammaStdDev": 2.018 + }, + "torsions": { + "phiMean": 57.633, + "phiStdDev": 16.731, + "psiMean": 180.055, + "psiStdDev": 4.686, + "omegaMean": 168.839, + "omegaStdDev": 8.195 + }, + "priority": false } ] }, @@ -345,21 +378,22 @@ "clusters": [ { "angles": { - "alphaMean": 127.838, - "alphaStdDev": 5.582, - "betaMean": 109.84, - "betaStdDev": 5.582, - "gammaMean": 116.699, - "gammaStdDev": 2.541 + "alphaMean": 125.332, + "alphaStdDev": 3.684, + "betaMean": 109.285, + "betaStdDev": 3.684, + "gammaMean": 117.773, + "gammaStdDev": 1.929 }, "torsions": { - "phiMean": 116.884, - "phiStdDev": 25.889, - "psiMean": -176.651, - "psiStdDev": 11.025, - "omegaMean": 162.145, - "omegaStdDev": 35.401 - } + "phiMean": 120.063, + "phiStdDev": 13.997, + "psiMean": 183.105, + "psiStdDev": 5.923, + "omegaMean": 166.487, + "omegaStdDev": 27.968 + }, + "priority": true } ] }, @@ -372,21 +406,22 @@ "clusters": [ { "angles": { - "alphaMean": 113.548, - "alphaStdDev": 5.876, - "betaMean": 111.95, - "betaStdDev": 5.876, - "gammaMean": 114.725, - "gammaStdDev": 1.774 + "alphaMean": 113.152, + "alphaStdDev": 4.541, + "betaMean": 112.019, + "betaStdDev": 4.541, + "gammaMean": 114.806, + "gammaStdDev": 1.708 }, "torsions": { - "phiMean": 72.866, - "phiStdDev": 28.094, - 
"psiMean": 175.3, - "psiStdDev": 33.92, - "omegaMean": 60.137, - "omegaStdDev": 10.249 - } + "phiMean": 69.173, + "phiStdDev": 21.872, + "psiMean": 172.894, + "psiStdDev": 30.103, + "omegaMean": 60.268, + "omegaStdDev": 9.139 + }, + "priority": true } ] }, @@ -399,21 +434,22 @@ "clusters": [ { "angles": { - "alphaMean": 113.367, - "alphaStdDev": 6.634, - "betaMean": 111.89, - "betaStdDev": 6.634, - "gammaMean": 114.592, - "gammaStdDev": 2.038 + "alphaMean": 112.707, + "alphaStdDev": 5.556, + "betaMean": 111.711, + "betaStdDev": 5.556, + "gammaMean": 114.497, + "gammaStdDev": 1.981 }, "torsions": { - "phiMean": 86.251, - "phiStdDev": 27.802, - "psiMean": 130.215, - "psiStdDev": 25.722, - "omegaMean": 61.975, - "omegaStdDev": 11.282 - } + "phiMean": 88.006, + "phiStdDev": 28.579, + "psiMean": 129.928, + "psiStdDev": 25.827, + "omegaMean": 62.706, + "omegaStdDev": 10.134 + }, + "priority": true } ] }, @@ -426,39 +462,60 @@ "clusters": [ { "angles": { - "alphaMean": 117.289, - "alphaStdDev": 7.737, - "betaMean": 110.997, - "betaStdDev": 7.737, - "gammaMean": 113.811, - "gammaStdDev": 1.418 + "alphaMean": 118.5, + "alphaStdDev": 4.218, + "betaMean": 110.849, + "betaStdDev": 4.218, + "gammaMean": 114.828, + "gammaStdDev": 1.339 }, "torsions": { - "phiMean": 30.089, - "phiStdDev": 51.985, - "psiMean": -124.86, - "psiStdDev": 27.34, - "omegaMean": 175.963, - "omegaStdDev": 7.299 - } + "phiMean": -92.052, + "phiStdDev": 13.796, + "psiMean": 61.167, + "psiStdDev": 15.518, + "omegaMean": -178.618, + "omegaStdDev": 6.176 + }, + "priority": false }, { "angles": { - "alphaMean": 112.665, - "alphaStdDev": 4.676, - "betaMean": 111.191, - "betaStdDev": 4.676, - "gammaMean": 113.81, - "gammaStdDev": 1.316 + "alphaMean": 113.04, + "alphaStdDev": 3.917, + "betaMean": 111.017, + "betaStdDev": 3.917, + "gammaMean": 113.571, + "gammaStdDev": 1.718 }, "torsions": { - "phiMean": -80.049, - "phiStdDev": 17.934, - "psiMean": -127.513, - "psiStdDev": 24.467, - "omegaMean": 179.044, - 
"omegaStdDev": 4.33 - } + "phiMean": -79.724, + "phiStdDev": 16.764, + "psiMean": -127.925, + "psiStdDev": 15.296, + "omegaMean": 178.681, + "omegaStdDev": 4.643 + }, + "priority": true + }, + { + "angles": { + "alphaMean": 120.232, + "alphaStdDev": 3.25, + "betaMean": 109.47, + "betaStdDev": 3.25, + "gammaMean": 114.125, + "gammaStdDev": 2.327 + }, + "torsions": { + "phiMean": 85.744, + "phiStdDev": 12.098, + "psiMean": -109.446, + "psiStdDev": 7.651, + "omegaMean": 172.163, + "omegaStdDev": 6.743 + }, + "priority": false } ] }, @@ -471,39 +528,41 @@ "clusters": [ { "angles": { - "alphaMean": 112.045, - "alphaStdDev": 5.979, - "betaMean": 111.561, - "betaStdDev": 5.979, - "gammaMean": 114.921, - "gammaStdDev": 1.979 + "alphaMean": 111.943, + "alphaStdDev": 3.334, + "betaMean": 111.523, + "betaStdDev": 3.334, + "gammaMean": 115.16, + "gammaStdDev": 1.963 }, "torsions": { - "phiMean": -83.971, - "phiStdDev": 30.056, - "psiMean": -159.668, - "psiStdDev": 27.959, - "omegaMean": -63.8, - "omegaStdDev": 8.514 - } + "phiMean": -73.943, + "phiStdDev": 8.372, + "psiMean": -174.486, + "psiStdDev": 11.471, + "omegaMean": -62.223, + "omegaStdDev": 4.214 + }, + "priority": true }, { "angles": { - "alphaMean": 111.376, - "alphaStdDev": 3.955, - "betaMean": 112.227, - "betaStdDev": 3.955, - "gammaMean": 115.159, - "gammaStdDev": 2.054 + "alphaMean": 111.501, + "alphaStdDev": 2.805, + "betaMean": 112.175, + "betaStdDev": 2.805, + "gammaMean": 115.277, + "gammaStdDev": 1.588 }, "torsions": { - "phiMean": -81.232, - "phiStdDev": 30.292, - "psiMean": 125.964, - "psiStdDev": 23.346, - "omegaMean": -65.062, - "omegaStdDev": 8.861 - } + "phiMean": -72.707, + "phiStdDev": 5.713, + "psiMean": 125.347, + "psiStdDev": 17.233, + "omegaMean": -62.404, + "omegaStdDev": 3.752 + }, + "priority": false } ] }, @@ -516,39 +575,22 @@ "clusters": [ { "angles": { - "alphaMean": 114.468, - "alphaStdDev": 3.421, - "betaMean": 113.049, - "betaStdDev": 3.421, - "gammaMean": 115.227, - "gammaStdDev": 
2.164 + "alphaMean": 114.315, + "alphaStdDev": 4.321, + "betaMean": 113.357, + "betaStdDev": 4.321, + "gammaMean": 114.783, + "gammaStdDev": 2.161 }, "torsions": { - "phiMean": -67.902, - "phiStdDev": 13.039, - "psiMean": 139.754, - "psiStdDev": 8.787, - "omegaMean": -65.706, - "omegaStdDev": 7.385 - } - }, - { - "angles": { - "alphaMean": 120.567, - "alphaStdDev": 5.481, - "betaMean": 110.511, - "betaStdDev": 5.481, - "gammaMean": 115.25, - "gammaStdDev": 2.348 + "phiMean": -70.877, + "phiStdDev": 12.529, + "psiMean": 137.662, + "psiStdDev": 14.607, + "omegaMean": -63.846, + "omegaStdDev": 6.328 }, - "torsions": { - "phiMean": -108.059, - "phiStdDev": 44.452, - "psiMean": 10.613, - "psiStdDev": 44.212, - "omegaMean": -75.286, - "omegaStdDev": 13.871 - } + "priority": true } ] }, @@ -561,39 +603,60 @@ "clusters": [ { "angles": { - "alphaMean": 112.963, - "alphaStdDev": 4.496, - "betaMean": 111.259, - "betaStdDev": 4.496, - "gammaMean": 113.771, - "gammaStdDev": 1.62 + "alphaMean": 118.887, + "alphaStdDev": 3.828, + "betaMean": 110.238, + "betaStdDev": 3.828, + "gammaMean": 115.762, + "gammaStdDev": 2.336 }, "torsions": { - "phiMean": -83.669, - "phiStdDev": 38.257, - "psiMean": -131.478, - "psiStdDev": 14.602, - "omegaMean": 178.472, - "omegaStdDev": 6.132 - } + "phiMean": -117.12, + "phiStdDev": 17.269, + "psiMean": 69.913, + "psiStdDev": 15.995, + "omegaMean": -177.385, + "omegaStdDev": 10.382 + }, + "priority": false }, { "angles": { - "alphaMean": 118.402, - "alphaStdDev": 5.23, - "betaMean": 111.058, - "betaStdDev": 5.23, - "gammaMean": 114.405, - "gammaStdDev": 1.96 + "alphaMean": 112.92, + "alphaStdDev": 3.717, + "betaMean": 110.727, + "betaStdDev": 3.717, + "gammaMean": 113.734, + "gammaStdDev": 1.855 }, "torsions": { - "phiMean": -110.737, - "phiStdDev": 41.555, - "psiMean": 78.36, - "psiStdDev": 25.132, - "omegaMean": 179.908, - "omegaStdDev": 8.293 - } + "phiMean": -84.638, + "phiStdDev": 16.181, + "psiMean": -131.408, + "psiStdDev": 14.169, + 
"omegaMean": 179.148, + "omegaStdDev": 5.325 + }, + "priority": true + }, + { + "angles": { + "alphaMean": 117.773, + "alphaStdDev": 3.189, + "betaMean": 113.615, + "betaStdDev": 3.189, + "gammaMean": 111.687, + "gammaStdDev": 1.326 + }, + "torsions": { + "phiMean": 51.356, + "phiStdDev": 9.507, + "psiMean": -121.891, + "psiStdDev": 7.708, + "omegaMean": 171.466, + "omegaStdDev": 3.746 + }, + "priority": false } ] }, @@ -606,39 +669,41 @@ "clusters": [ { "angles": { - "alphaMean": 112.696, - "alphaStdDev": 4.802, - "betaMean": 111.185, - "betaStdDev": 4.802, - "gammaMean": 114.511, - "gammaStdDev": 1.534 + "alphaMean": 111.884, + "alphaStdDev": 3.674, + "betaMean": 112.374, + "betaStdDev": 3.674, + "gammaMean": 115.079, + "gammaStdDev": 2.103 }, "torsions": { - "phiMean": 124.141, - "phiStdDev": 27.455, - "psiMean": 167.284, - "psiStdDev": 52.376, - "omegaMean": 62.53, - "omegaStdDev": 6.109 - } + "phiMean": 79.668, + "phiStdDev": 22.254, + "psiMean": 122.875, + "psiStdDev": 21.525, + "omegaMean": 63.173, + "omegaStdDev": 6.833 + }, + "priority": true }, { "angles": { - "alphaMean": 111.777, - "alphaStdDev": 4.176, - "betaMean": 112.14, - "betaStdDev": 4.176, - "gammaMean": 114.921, - "gammaStdDev": 1.929 + "alphaMean": 116.704, + "alphaStdDev": 2.744, + "betaMean": 110.516, + "betaStdDev": 2.744, + "gammaMean": 114.326, + "gammaStdDev": 1.492 }, "torsions": { - "phiMean": 76.691, - "phiStdDev": 17.032, - "psiMean": 120.879, - "psiStdDev": 18.587, - "omegaMean": 63.32, - "omegaStdDev": 6.55 - } + "phiMean": 115.031, + "phiStdDev": 15.957, + "psiMean": -45.776, + "psiStdDev": 18.655, + "omegaMean": 60.43, + "omegaStdDev": 5.04 + }, + "priority": false } ] }, @@ -651,39 +716,41 @@ "clusters": [ { "angles": { - "alphaMean": 111.245, - "alphaStdDev": 5.137, - "betaMean": 111.618, - "betaStdDev": 5.137, - "gammaMean": 114.578, - "gammaStdDev": 2.196 + "alphaMean": 111.083, + "alphaStdDev": 3.186, + "betaMean": 111.653, + "betaStdDev": 3.186, + "gammaMean": 114.791, + 
"gammaStdDev": 1.861 }, "torsions": { - "phiMean": 91.041, - "phiStdDev": 36.251, - "psiMean": -165.287, - "psiStdDev": 29.928, - "omegaMean": 62.662, - "omegaStdDev": 6.365 - } + "phiMean": 75.669, + "phiStdDev": 23.342, + "psiMean": 178.455, + "psiStdDev": 13.232, + "omegaMean": 61.6, + "omegaStdDev": 5.504 + }, + "priority": true }, { "angles": { - "alphaMean": 111.215, - "alphaStdDev": 5.15, - "betaMean": 111.891, - "betaStdDev": 5.15, - "gammaMean": 114.514, - "gammaStdDev": 1.803 + "alphaMean": 112.288, + "alphaStdDev": 3.591, + "betaMean": 113.29, + "betaStdDev": 3.591, + "gammaMean": 114.578, + "gammaStdDev": 1.955 }, "torsions": { - "phiMean": 99.177, - "phiStdDev": 38.324, - "psiMean": 122.388, - "psiStdDev": 28.196, - "omegaMean": 62.79, - "omegaStdDev": 6.586 - } + "phiMean": 84.192, + "phiStdDev": 19.648, + "psiMean": 87.973, + "psiStdDev": 8.041, + "omegaMean": 59.971, + "omegaStdDev": 5.414 + }, + "priority": false } ] }, @@ -696,39 +763,41 @@ "clusters": [ { "angles": { - "alphaMean": 113.427, - "alphaStdDev": 3.654, - "betaMean": 110.13, - "betaStdDev": 3.654, - "gammaMean": 114.421, - "gammaStdDev": 1.585 + "alphaMean": 111.999, + "alphaStdDev": 3.669, + "betaMean": 111.888, + "betaStdDev": 3.669, + "gammaMean": 114.916, + "gammaStdDev": 1.883 }, "torsions": { - "phiMean": 149.427, - "phiStdDev": 29.129, - "psiMean": 166.702, - "psiStdDev": 28.999, - "omegaMean": 62.546, - "omegaStdDev": 7.578 - } + "phiMean": 81.869, + "phiStdDev": 11.331, + "psiMean": 135.629, + "psiStdDev": 18.959, + "omegaMean": 61.031, + "omegaStdDev": 4.842 + }, + "priority": true }, { "angles": { - "alphaMean": 111.736, - "alphaStdDev": 3.807, - "betaMean": 111.726, - "betaStdDev": 3.807, - "gammaMean": 114.934, - "gammaStdDev": 1.849 + "alphaMean": 116.877, + "alphaStdDev": 2.86, + "betaMean": 114.58, + "betaStdDev": 2.86, + "gammaMean": 114.957, + "gammaStdDev": 2.177 }, "torsions": { - "phiMean": 81.341, - "phiStdDev": 15.082, - "psiMean": 127.531, - "psiStdDev": 21.67, 
- "omegaMean": 62.113, - "omegaStdDev": 5.308 - } + "phiMean": 61.352, + "phiStdDev": 3.213, + "psiMean": 68.128, + "psiStdDev": 4.275, + "omegaMean": 61.197, + "omegaStdDev": 3.494 + }, + "priority": false } ] }, @@ -741,39 +810,41 @@ "clusters": [ { "angles": { - "alphaMean": 111.267, - "alphaStdDev": 4.451, - "betaMean": 112.338, - "betaStdDev": 4.451, - "gammaMean": 115.005, - "gammaStdDev": 1.642 + "alphaMean": 112.28, + "alphaStdDev": 4.978, + "betaMean": 112.174, + "betaStdDev": 4.978, + "gammaMean": 114.722, + "gammaStdDev": 1.941 }, "torsions": { - "phiMean": 87.585, - "phiStdDev": 26.284, - "psiMean": 127.542, - "psiStdDev": 21.99, - "omegaMean": 63.792, - "omegaStdDev": 6.819 - } + "phiMean": 81.551, + "phiStdDev": 19.6, + "psiMean": 130.234, + "psiStdDev": 22.768, + "omegaMean": 61.843, + "omegaStdDev": 6.198 + }, + "priority": true }, { "angles": { - "alphaMean": 117.959, - "alphaStdDev": 8.512, - "betaMean": 110.22, - "betaStdDev": 8.512, - "gammaMean": 114.023, - "gammaStdDev": 1.649 + "alphaMean": 116.969, + "alphaStdDev": 2.361, + "betaMean": 110.661, + "betaStdDev": 2.361, + "gammaMean": 114.473, + "gammaStdDev": 1.862 }, "torsions": { - "phiMean": 109.582, - "phiStdDev": 28.224, - "psiMean": -37.08, - "psiStdDev": 30.971, - "omegaMean": 60.351, - "omegaStdDev": 8.854 - } + "phiMean": 108.755, + "phiStdDev": 22.812, + "psiMean": -41.195, + "psiStdDev": 22.631, + "omegaMean": 59.388, + "omegaStdDev": 7.295 + }, + "priority": false } ] }, @@ -786,38 +857,67 @@ "clusters": [ { "angles": { - "alphaMean": 112.164, - "alphaStdDev": 5.285, - "betaMean": 111.948, - "betaStdDev": 5.285, - "gammaMean": 114.89, - "gammaStdDev": 1.67 + "alphaMean": 111.914, + "alphaStdDev": 3.471, + "betaMean": 111.639, + "betaStdDev": 3.471, + "gammaMean": 114.69, + "gammaStdDev": 2.042 }, "torsions": { - "phiMean": 69.439, - "phiStdDev": 16.826, - "psiMean": 178.823, - "psiStdDev": 36.578, - "omegaMean": 63.355, - "omegaStdDev": 7.64 - } + "phiMean": 69.712, + "phiStdDev": 
11.052, + "psiMean": -174.851, + "psiStdDev": 12.285, + "omegaMean": 60.498, + "omegaStdDev": 5.127 + }, + "priority": true }, { "angles": { - "alphaMean": 110.925, - "alphaStdDev": 3.37, - "betaMean": 110.781, - "betaStdDev": 3.37, - "gammaMean": 114.455, - "gammaStdDev": 1.614 + "alphaMean": 107.91, + "alphaStdDev": 1.048, + "betaMean": 110.222, + "betaStdDev": 1.048, + "gammaMean": 115.455, + "gammaStdDev": 0.482 + }, + "torsions": { + "phiMean": 120.209, + "phiStdDev": 10.104, + "psiMean": -143.923, + "psiStdDev": 5.034, + "omegaMean": 67.549, + "omegaStdDev": 2.056 + }, + "priority": false + } + ] + }, + { + "donorResidue": "MAN", + "acceptorResidue": "NAG", + "donorNumber": 2, + "acceptorNumber": 1, + "length": 1.4, + "clusters": [ + { + "angles": { + "alphaMean": 112.374, + "alphaStdDev": 4.559, + "betaMean": 110.391, + "betaStdDev": 4.559, + "gammaMean": 113.604, + "gammaStdDev": 1.857 }, "torsions": { - "phiMean": 145.024, - "phiStdDev": 29.365, - "psiMean": -169.354, - "psiStdDev": 54.477, - "omegaMean": 63.742, - "omegaStdDev": 7.224 + "phiMean": -85.243, + "phiStdDev": 20.485, + "psiMean": 148.418, + "psiStdDev": 20.213, + "omegaMean": 178.521, + "omegaStdDev": 5.397 } } ] diff --git a/package/src/sails/find.py b/package/src/sails/find.py index 0b40fbe..90f4220 100644 --- a/package/src/sails/find.py +++ b/package/src/sails/find.py @@ -1,9 +1,17 @@ +import importlib +import time +from argparse import ArgumentError from collections import defaultdict from pathlib import Path from typing import List, Tuple import gemmi import argparse import json +from sails import identify_predicted_sites, GlycoSite +from .interface import get_sails_structure, get_sails_map +from .glycosylate import read_prediction_dir, save_log +from .prediction.model import ModelType +from .prediction.predict import predict_map def find_n_glycosylation_sites(structure: gemmi.Structure): @@ -94,27 +102,7 @@ def format_sites( return d -def run(): - """ - Parse command-line arguments, read 
PDB model, find glycosylation sites, - and write the results to an output file in JSON format. - - :return: None - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "-modelin", required=True, type=str, help="Path to a model in PDB or CIF format" - ) - parser.add_argument( - "-logout", - required=False, - default="sites.json", - type=str, - help="Path to output file", - ) - - args = parser.parse_args() - +def sequence_find(args: argparse.Namespace): pdb_path = Path(args.modelin) if not pdb_path.exists(): raise FileNotFoundError("Could not find specified file") @@ -132,3 +120,240 @@ def run(): with open(args.logout, "w") as f: json.dump(data, f, indent=4) + + +def convert_residue_name_to_type(residue_name: str) -> str: + n_glycans = ["ASN"] + o_glycans = ["SER", "THR"] + c_glycans = ["TRP"] + + if residue_name in n_glycans: + return "n-glycan" + elif residue_name in o_glycans: + return "o-glycan" + elif residue_name in c_glycans: + return "c-glycan" + return "x-glycan" + + +def convert_glycosites_to_log( + glycosites: List[GlycoSite], structure: gemmi.Structure | Path | str +): + if isinstance(structure, str) or isinstance(structure, Path): + structure = gemmi.read_structure(str(structure)) + + keys = defaultdict(list) + for glycosite in glycosites: + model = structure[glycosite.model_idx] + chain = model[glycosite.chain_idx] + residue = chain[glycosite.residue_idx] + key = f"{chain.name}-{residue.name}-{residue.seqid.num}" + keys[convert_residue_name_to_type(residue.name)].append(key) + + return keys + + +def get_amplitude_phase(args): + if "," not in args.colin_fwt: + raise ArgumentError("FWT column should be comma separated") + return args.colin_fwt.split(",") + + +def xray(args): + sails_structure = get_sails_structure(args.modelin) + resource = importlib.resources.files("sails").joinpath("data") + model = ModelType[args.modeltype] + if args.preddirin: + predictions = read_prediction_dir(args.preddirin, model) + else: + amplitude, phase = 
get_amplitude_phase(args) + predictions = predict_map( + model.name, + args.mtzin, + "output", + nthreads=8, + amplitude=amplitude, + phase=phase, + save_map=True, + ) + + if model == ModelType.binary: + glycan_predicted_map = predictions + sails_grid = get_sails_map(glycan_predicted_map) + result = identify_predicted_sites(sails_structure, sails_grid, str(resource)) + else: + glycan_predicted_map, protein_predicted_map = predictions + sails_glycan_grid = get_sails_map(glycan_predicted_map) + sails_protein_grid = get_sails_map(protein_predicted_map) + searchtype = args.searchtype + result = identify_predicted_sites( + sails_structure, + sails_glycan_grid, + sails_protein_grid, + searchtype == "glycan", + str(resource), + ) + + log = convert_glycosites_to_log(result, args.modelin) + save_log(log, args) + + +def em(args): + sails_structure = get_sails_structure(args.modelin) + resource = importlib.resources.files("sails").joinpath("data") + model = ModelType[args.modeltype] + + if args.preddirin: + predictions = read_prediction_dir(args.preddirin, model) + else: + predictions = predict_map( + model.name, + args.mapin, + "output", + nthreads=8, + save_map=True, + ) + + if model == ModelType.binary: + glycan_predicted_map = predictions + sails_grid = get_sails_map(glycan_predicted_map) + result = identify_predicted_sites(sails_structure, sails_grid, str(resource)) + else: + glycan_predicted_map, protein_predicted_map = predictions + sails_glycan_grid = get_sails_map(glycan_predicted_map) + sails_protein_grid = get_sails_map(protein_predicted_map) + searchtype = args.searchtype + result = identify_predicted_sites( + sails_structure, + sails_glycan_grid, + sails_protein_grid, + searchtype == "glycan", + str(resource), + ) + + log = convert_glycosites_to_log(result, args.modelin) + save_log(log, args) + + +def density_find(args: argparse.Namespace): + t0 = time.time() + + if args.source == "xray": + xray(args) + elif args.source == "em": + em(args) + else: + raise 
RuntimeError("Unknown mode") + + t1 = time.time() + print(f"Sails Density Identification - Time Taken = {(t1 - t0)} seconds") + + +def run(): + """ + Parse command-line arguments, read PDB model, find glycosylation sites, + and write the results to an output file in JSON format. + + :return: None + """ + + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(dest="mode", required=True) + + seq_parser = subparsers.add_parser("seq") + seq_parser.add_argument( + "--modelin", + required=True, + type=str, + help="Path to a model in PDB or CIF format", + ) + seq_parser.add_argument( + "--logout", + required=False, + default="sites.json", + type=str, + help="Path to output file", + ) + + density_parser = subparsers.add_parser("density") + density_subparser = density_parser.add_subparsers(dest="source", required=True) + xray_parser = density_subparser.add_parser("xray") + xray_parser.add_argument( + "--mtzin", required=True, type=str, help="Path to mtz file" + ) + xray_parser.add_argument( + "--modelin", + required=True, + type=str, + help="Path to a model in PDB or CIF format", + ) + xray_parser.add_argument( + "--preddirin", + required=False, + type=str, + help="Path to a model in PDB or CIF format", + ) + xray_parser.add_argument( + "--logout", + required=False, + default="sites.json", + type=str, + help="Path to output file", + ) + xray_parser.add_argument( + "--modeltype", + required=True, + choices=[type.name for type in ModelType], + help="Binary or Multiclass model", + ) + xray_parser.add_argument( + "--searchtype", + required=True, + choices=["protein", "glycan"], + help="Search for protein or glycan, only used if modeltype is multiclass", + ) + xray_parser.add_argument("--colin-fo", type=str, required=False, default="FP,SIGFP") + xray_parser.add_argument( + "--colin-fwt", type=str, required=False, default="FWT,PHWT" + ) + + em_parser = density_subparser.add_parser("em") + em_parser.add_argument("--mapin", type=str, required=True) + 
em_parser.add_argument( + "--modelin", + required=True, + type=str, + help="Path to a model in PDB or CIF format", + ) + em_parser.add_argument( + "--logout", + required=False, + default="sites.json", + type=str, + help="Path to output file", + ) + em_parser.add_argument( + "--preddirin", + required=False, + type=str, + help="Path to a model in PDB or CIF format", + ) + em_parser.add_argument( + "--modeltype", + required=True, + choices=[type.name for type in ModelType], + help="Binary or Multiclass model", + ) + em_parser.add_argument( + "--searchtype", + required=True, + choices=["protein", "glycan"], + help="Search for protein or glycan, only used if modeltype is multiclass", + ) + + args = parser.parse_args() + if args.mode == "seq": + sequence_find(args) + elif args.mode == "density": + density_find(args) diff --git a/package/src/sails/glycosylate.py b/package/src/sails/glycosylate.py index eadbc50..768b580 100644 --- a/package/src/sails/glycosylate.py +++ b/package/src/sails/glycosylate.py @@ -8,13 +8,24 @@ from typing import Tuple, List import gemmi -from sails import interface, n_glycosylate, c_glycosylate, o_mannosylate, __version__ +from sails import ( + interface, + n_glycosylate, + c_glycosylate, + o_mannosylate, + __version__, + auto_glycosylate, + glycosylate_site, +) +from .prediction.model import ModelType +from .prediction.predict import predict_map class Type(enum.IntEnum): n_glycosylate = 1 c_glycosylate = 2 o_mannosylate = 3 + auto = 4 def __str__(self): return self.name @@ -37,12 +48,40 @@ def map_type_to_function(type: Type): if type == Type.o_mannosylate: return o_mannosylate + if type == Type.auto: + return auto_glycosylate + raise TypeError("Type not found") +def read_prediction_dir( + path: Path | str, model_type: ModelType +) -> gemmi.FloatGrid | Tuple[gemmi.FloatGrid, gemmi.FloatGrid]: + path = Path(path) + glycan_path = path / "sails-glycan.map" + protein_path = path / "sails-protein.map" + + if not glycan_path.exists(): + raise 
FileNotFoundError(glycan_path) + + if model_type == ModelType.multiclass: + if not protein_path.exists(): + raise FileNotFoundError(protein_path) + + glycan_map = gemmi.read_ccp4_map(str(glycan_path)) + + if model_type == ModelType.multiclass: + protein_map = gemmi.read_ccp4_map(str(protein_path)) + return glycan_map.grid, protein_map.grid + return glycan_map.grid + + def glycosylate_xtal( structure: gemmi.Structure | Path | str, mtz: gemmi.Mtz | Path | str, + preddirin: Path | str, + chain: str, + seqid: int | str, cycles: int, f: str, sigf: str, @@ -68,8 +107,48 @@ def glycosylate_xtal( sails_mtz = interface.get_sails_mtz(mtz, f, sigf, fwt, phwt) resource = importlib.resources.files("sails").joinpath("data") - func = map_type_to_function(type) - result = func(sails_structure, sails_mtz, cycles, str(resource), verbose) + if chain and seqid: + result = glycosylate_site( + sails_structure, + sails_mtz, + chain, + int(seqid), + cycles, + str(resource), + verbose, + ) + return ( + interface.extract_sails_structure(result.structure), + interface.extract_sails_mtz(result.mtz), + json.loads(result.log), + result.snfgs, + ) + + if type == Type.auto: + if preddirin: + predictions = read_prediction_dir( + preddirin, model_type=ModelType.multiclass + ) + else: + predictions = predict_map( + "multiclass", mtz, "output", nthreads=8, save_map=True + ) + glycan, protein = predictions + sails_glycan = interface.get_sails_map(glycan) + sails_protein = interface.get_sails_map(protein) + + result = auto_glycosylate( + sails_structure, + sails_mtz, + sails_glycan, + sails_protein, + cycles, + str(resource), + verbose, + ) + else: + func = map_type_to_function(type) + result = func(sails_structure, sails_mtz, cycles, str(resource), verbose) return ( interface.extract_sails_structure(result.structure), @@ -82,6 +161,10 @@ def glycosylate_xtal( def glycosylate_em( structure: gemmi.Structure | Path | str, map: gemmi.Ccp4Map | gemmi.FloatGrid | Path | str, + preddirin: Path | str, + 
resolution: float, + chain: str, + seqid: int | str, cycles: int, type: Type = Type.n_glycosylate, verbose: bool = False, @@ -90,8 +173,51 @@ def glycosylate_em( sails_grid = interface.get_sails_map(map) resource = importlib.resources.files("sails").joinpath("data") - func = map_type_to_function(type) - result = func(sails_structure, sails_grid, cycles, str(resource), verbose) + if chain and seqid: + result = glycosylate_site( + sails_structure, + sails_grid, + resolution, + chain, + int(seqid), + cycles, + str(resource), + verbose, + ) + return ( + interface.extract_sails_structure(result.structure), + json.loads(result.log), + result.snfgs, + ) + + if type == Type.auto: + if preddirin: + predictions = read_prediction_dir( + preddirin, model_type=ModelType.multiclass + ) + else: + predictions = predict_map( + "multiclass", map, "output", nthreads=8, save_map=True + ) + glycan, protein = predictions + sails_glycan = interface.get_sails_map(glycan) + sails_protein = interface.get_sails_map(protein) + + result = auto_glycosylate( + sails_structure, + sails_grid, + resolution, + sails_glycan, + sails_protein, + cycles, + str(resource), + verbose, + ) + else: + func = map_type_to_function(type) + result = func( + sails_structure, sails_grid, resolution, cycles, str(resource), verbose + ) return ( interface.extract_sails_structure(result.structure), @@ -160,9 +286,19 @@ def save_snfgs(snfgs: dict, snfg_path: Path): def xray(args): labels = get_column_labels(args.colin_fo, args.colin_fwt) - cycles = args.cycles if args.type == Type.n_glycosylate else 1 + cycles = ( + args.cycles if args.type == Type.n_glycosylate or args.type == Type.auto else 1 + ) structure, mtz, log, snfgs = glycosylate_xtal( - args.modelin, args.mtzin, cycles, *labels, args.type, args.v + args.modelin, + args.mtzin, + args.preddirin, + args.chain, + args.seqid, + cycles, + *labels, + args.type, + args.v, ) if args.snfgout: @@ -175,9 +311,19 @@ def xray(args): def em(args): - cycles = args.cycles if 
args.type == Type.n_glycosylate else 1 + cycles = ( + args.cycles if args.type == Type.n_glycosylate or args.type == Type.auto else 1 + ) structure, log, snfgs = glycosylate_em( - args.modelin, args.mapin, cycles, args.type, args.v + args.modelin, + args.mapin, + args.preddirin, + args.resolution, + args.chain, + args.seqid, + cycles, + args.type, + args.v, ) structure.make_mmcif_block().write_file(args.modelout) save_log(log, args) @@ -207,16 +353,19 @@ def parse_args(): parent = argparse.ArgumentParser(add_help=False) group = parent.add_argument_group("Required arguments for all modes") group.add_argument("-v", action=argparse.BooleanOptionalAction, default=False) - group.add_argument("-modelin", type=str, required=True) + group.add_argument("--modelin", type=str, required=True) + group.add_argument("--preddirin", type=str, required=False) group.add_argument( - "-modelout", type=str, required=False, default="sails-model-out.cif" + "--modelout", type=str, required=False, default="sails-model-out.cif" ) - group.add_argument("-logout", type=str, default="sails-log.json") - group.add_argument("-snfgout", type=str) - group.add_argument("-cycles", type=int, required=False, default=2) + group.add_argument("--logout", type=str, default="sails-log.json") + group.add_argument("--snfgout", type=str) + group.add_argument("--cycles", type=int, required=False, default=2) group.add_argument( - "-type", type=Type.from_string, choices=list(Type), default=Type.n_glycosylate + "--type", type=Type.from_string, choices=list(Type), default=Type.auto ) + group.add_argument("--chain", type=str, required=False) + group.add_argument("--seqid", type=str, required=False) formatter = argparse.ArgumentDefaultsHelpFormatter xray_parser = subparsers.add_parser( @@ -225,17 +374,18 @@ def parse_args(): xray_parser_group = xray_parser.add_argument_group( "Required arguments in X-ray mode" ) - xray_parser_group.add_argument("-mtzin", type=str, required=True) + 
class InstallLocation(enum.Enum):
    """Places a downloaded model can be installed into.

    The member names double as the accepted values of the ``--output``
    command line option of the installer.
    """

    # Install below the first directory returned by site.getsitepackages().
    site_packages = 0
    # Install below the directory named by the CLIBD environment variable.
    ccp4 = 1
def install_model(type: model.ModelType, location: str, reinstall: bool) -> bool:
    """Download a model into the requested install location.

    Args:
        type: Which model to download.
        location: Name of an ``InstallLocation`` member ("ccp4" or
            "site_packages"). An unknown name raises ``KeyError``.
        reinstall: Forwarded to ``model.download_model``.

    Returns:
        True when a download was attempted, False when the CCP4 location was
        requested but the CLIBD directory could not be found.

    Raises:
        RuntimeError: If no site-packages directory is reported.
    """
    logging.info(f"Installing {type.name} model to {location}")
    destination = InstallLocation[location]

    if destination == InstallLocation.ccp4:
        clibd = os.environ.get("CLIBD", "")
        # An unset variable yields "", for which os.path.exists() is False.
        if not os.path.exists(clibd):
            clibd_error_msg()
            return False

        # download_model joins the folder with "/", so it must be a Path and
        # not the raw string read from the environment.
        model.download_model(type=type, folder=Path(clibd), reinstall=reinstall)
        return True

    if destination == InstallLocation.site_packages:
        site_packages_dir = site.getsitepackages()
        if not site_packages_dir:
            raise RuntimeError(
                "Failed to get site packages directory, ensure you in a virtual environment"
            )
        first_sitepackages = Path(site_packages_dir[0])
        model.download_model(type=type, folder=first_sitepackages, reinstall=reinstall)
        return True
    return False
def extract_sails_connections(connections: sails.Connections):
    """Build a gemmi.ConnectionList from an iterable of sails connections.

    Each connection's type is looked up in gemmi.ConnectionType by name and
    both partners are converted with extract_sails_atom_address.
    """
    converted = gemmi.ConnectionList()
    for source in connections:
        target = gemmi.Connection()
        target.type = gemmi.ConnectionType[source.type.__name__]
        target.partner1 = extract_sails_atom_address(source.partner1)
        target.partner2 = extract_sails_atom_address(source.partner2)
        converted.append(target)
    return converted
def parse_arguments() -> SimpleNamespace:
    """Parse the prediction command line.

    Returns:
        A SimpleNamespace with one attribute per option. ``resolution`` is a
        float (or None when not given) and every on/off switch is a real
        bool, so the values can be handed to the prediction code unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        help="Model selection",
        choices=[model_type.name for model_type in ModelType],
        required=False,
    )
    parser.add_argument("-i", "--input", help="Input mtz", required=True)
    parser.add_argument(
        "-o",
        "--output",
        help="Output directory, if does not exist it will be created model",
        default="sails-output",
        required=False,
    )
    # Parsed as float so a cutoff can be compared against d-spacings directly.
    parser.add_argument(
        "-r", "--resolution", nargs="?", type=float, help="Resolution cutoff"
    )
    parser.add_argument(
        "-n",
        "--nthreads",
        nargs="?",
        default=None,
        type=int,
        help="Number of threads to use",
    )
    parser.add_argument(
        "--amplitude", "-f", nargs="?", help="Name of amplitude column in MTZ, e.g. FWT"
    )
    parser.add_argument(
        "--phase", "-phi", nargs="?", help="Name of phase column in MTZ, e.g. PHWT"
    )
    parser.add_argument(
        "--overlap",
        nargs="?",
        help="Amount of overlap to use",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--use-symmetry",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Compute predictions for the entire unit cell",
    )
    # The switches below default to False (not None), matching --silent and
    # --use-symmetry above.
    parser.add_argument(
        "--variance",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Output variance map",
    )
    parser.add_argument(
        "--raw",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Output raw map (no argmax)",
    )
    parser.add_argument(
        "--gpu",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use GPU (experimental)",
    )
    parser.add_argument(
        "--debug",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Turn on debug logging",
    )
    parser.add_argument(
        "--silent",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Turn off progress bar",
    )
    parser.add_argument("--model_path", nargs="?", help="Path to model (development)")
    parser.add_argument("-v", "--version", action="version", version=__version__)
    args = vars(parser.parse_args())
    return SimpleNamespace(**args)
model will output 1 for protein...""" + + glycan: int = 1 + protein: int = 2 diff --git a/package/src/sails/prediction/errors.py b/package/src/sails/prediction/errors.py new file mode 100644 index 0000000..e5b849f --- /dev/null +++ b/package/src/sails/prediction/errors.py @@ -0,0 +1,29 @@ +import logging +from typing import List + + +def show_missing_model_error(): + """Show error when no models are found""" + logging.critical(""" + No models have been found in either site_packages or CCP4/lib/data. + You can install models using the command: + sails-install -m {binary,multiclass} + """) + + +def show_missing_specified_model_error(model_name: str): + """Show error when model with specified name is not found""" + logging.critical(f""" + No model with the name {model_name} has been found in either site_packages or CCP4/lib/data.""") + + +def show_multiple_model_error(model_names: List[str]): + """Show warning when multiple models are found""" + multiple_model_names = "" + for model_name in model_names: + multiple_model_names += f"\t-m {model_name}\n" + + logging.warning(f""" + Multiple models have been found in either site_packages or CCP4/lib/data. 
def precompute_slices(grid_shape: np.ndarray, overlap: int = 16) -> List[List[int]]:
    """List the [i, j, k] start indices of every box to run inference on.

    Along each of the first three axes the starts run from 0 in steps of
    ``overlap`` and stop before ``grid_shape[axis] - overlap``. The first axis
    varies slowest in the returned list.
    """
    return [
        [start_x, start_y, start_z]
        for start_x in range(0, grid_shape[0] - overlap, overlap)
        for start_y in range(0, grid_shape[1] - overlap, overlap)
        for start_z in range(0, grid_shape[2] - overlap, overlap)
    ]
def load_mtz(
    path: Path | str,
    column_names: List[str] | None,
    resolution_cutoff: float | None,
) -> gemmi.FloatGrid:
    """Load an MTZ file and transform it into a normalised map.

    Args:
        path: Location of the MTZ file.
        column_names: Amplitude and phase column labels. When this is None or
            contains None, the labels are chosen by find_map_coefficients.
        resolution_cutoff: When given (and non-zero), reflections whose
            d-spacing is below this value are dropped before the map is
            calculated.

    Returns:
        The normalised grid, sampled so that the spacing is about 0.7 A at the
        resolution of the reflections actually used.
    """
    mtz = gemmi.read_mtz_file(str(path))
    if column_names is None or None in column_names:
        logging.warning(
            "No map coefficients were specified, Sails will try and find some but they may be wrong."
        )
        column_names = find_map_coefficients(mtz)

    # The cutoff has to be applied before the Fourier transform; truncating the
    # reflections after the grid has been computed has no effect on the map.
    if resolution_cutoff:
        data = np.array(mtz, copy=False)
        mtz.set_data(data[mtz.make_d_array() >= resolution_cutoff])
        # Refresh the stored resolution limits so the sample rate below is
        # derived from the truncated data.
        mtz.update_reso()

    res = mtz.resolution_high()
    spacing = 0.7
    sample_rate = res / spacing
    grid = mtz.transform_f_phi_to_map(*column_names, sample_rate=sample_rate)
    grid.normalize()
    return grid
def calculate_sha256(file_path: Path):
    """Return the hexadecimal SHA256 digest of the file at ``file_path``.

    The file is read in 4096-byte pieces so it never has to be held in memory
    as a whole.
    """
    logging.debug("Calculating SHA256 hash for %s", file_path)
    with open(file_path, "rb") as handle:
        digest = hashlib.sha256()
        # iter() with a sentinel stops at the first empty read (end of file).
        for piece in iter(lambda: handle.read(4096), b""):
            digest.update(piece)
    return digest.hexdigest()
def get_latest_model(type: ModelType) -> str:
    """Ask the HuggingFace API for the file name of the newest model.

    Args:
        type: Which model repository to query.

    Returns:
        The ``rfilename`` of the ``.onnx`` file that sorts last among the
        files listed for the repository.

    Raises:
        RuntimeError: If the API response is empty, lists no files, or lists
            no ``.onnx`` file.
    """
    base_url = "https://huggingface.co/api/models/Dialpuri/Sails"
    url = f"{base_url}-{type.name}"
    logging.debug("Getting latest model for %s from %s", type.name, url)
    response = requests.get(url)
    payload = response.json()
    if not payload:
        raise RuntimeError("Failed to get model URL")

    siblings = payload.get("siblings", None)
    if not siblings:
        raise RuntimeError("Failed to get siblings from model")

    possible_models = [
        entry.get("rfilename", "")
        for entry in siblings
        if entry.get("rfilename", "").endswith(".onnx")
    ]
    # Without this guard an empty repository listing would surface as a bare
    # IndexError instead of a descriptive error.
    if not possible_models:
        raise RuntimeError("Failed to find an .onnx model in the model repository")

    # The newest model is the name that sorts last.
    latest_model = max(possible_models)
    logging.debug("Latest model for %s is %s", type.name, latest_model)
    return latest_model
def find_all_potential_models():
    """Collect every ``*.onnx`` file from the known model directories.

    The ``sails_models`` folder of each site-packages directory is searched
    first, then the ``sails_models`` folder below the directory named by the
    CLIBD environment variable, if that variable is set and the directory
    exists.

    Returns:
        A list of Path objects, or None when nothing was found and either
        CLIBD is unavailable or it has no ``sails_models`` folder. If the CCP4
        folder exists but no model was found anywhere, the process exits with
        status 1.
    """
    model_extension = "*.onnx"

    potential_models = []

    for pkg in site.getsitepackages():
        model_directory = Path(pkg) / "sails_models"
        potential_models += list(model_directory.glob(model_extension))

    # Path("") is the current directory and always exists, so the raw string
    # must be checked before it is turned into a Path; otherwise an unset CLIBD
    # goes unnoticed and ./sails_models is searched instead.
    clibd_value = os.environ.get("CLIBD", "")
    clibd = Path(clibd_value) if clibd_value else None
    clibd_available = clibd is not None and clibd.exists()

    if not clibd_available and not potential_models:
        logging.warning(
            """CCP4 Environment Variable - CLIBD is not found.
            You can try sourcing it:
            Ubuntu - source /opt/xtal/ccp4-X.X/bin/ccp4.setup-sh
            MacOS - source /Applications/ccp4-X.X/bin/ccp4.setup-sh
            """
        )
        return None

    if clibd_available:
        ccp4_model_path = clibd / "sails_models"
        if not ccp4_model_path.exists() and not potential_models:
            show_missing_model_error()
            return None

        potential_models += list(ccp4_model_path.glob(model_extension))

    if not potential_models:
        show_missing_model_error()
        sys.exit(1)

    return [Path(x) for x in potential_models]
+ ) + name = match.group(1) + model_names.append(name) + return model_names + + +def find_model(model: ModelType | str | None) -> Tuple[Path, ModelType] | None: + """Search through site-packages and CCP4/lib/data for a potential model""" + potential_models = find_all_potential_models() + if not potential_models: + sys.exit(1) + + model_names = extract_model_names(potential_models) + if not model and len(potential_models) == 1: + return Path(potential_models[0]), ModelType[model_names[0]] + + if not model: + show_multiple_model_error(model_names) + sys.exit(1) + + if isinstance(model, ModelType): + specified_model_name = model.name + else: + specified_model_name = model + + for name in model_names: + if name == specified_model_name: + return Path(potential_models[model_names.index(name)]), ModelType[name] + + show_missing_specified_model_error(specified_model_name) + sys.exit(1) + + +def get_model_config(model_path: Path, overlap: int | None) -> SimpleNamespace: + """Get model configuration from model type""" + model_type = model_path.stem.removeprefix("sails-") + if model_type not in ModelType.__members__: + raise RuntimeError(f"Invalid model type - {model_type}") + model_type = ModelType[model_type] + match model_type: + case ModelType.binary: + return SimpleNamespace( + box_size=128, overlap=64 if overlap is None else overlap + ) + case ModelType.multiclass: + return SimpleNamespace( + box_size=128, overlap=64 if overlap is None else overlap, channels=3 + ) + case _: + raise RuntimeError(f"Invalid model type - {model_type}") diff --git a/package/src/sails/prediction/predict.py b/package/src/sails/prediction/predict.py new file mode 100644 index 0000000..60c2900 --- /dev/null +++ b/package/src/sails/prediction/predict.py @@ -0,0 +1,236 @@ +import functools +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List, Tuple + +import gemmi +import numpy as np +from tqdm import tqdm + +from .load import 
class Sails:
    """Runs an ONNX model over a density map and keeps the predicted grids."""

    def __init__(self, model_path: Path | str, configuration: Configuration):
        """Store the model location and settings; the model is loaded lazily by predict()."""
        self.model_path = model_path
        self.configuration = configuration

        self.model = None
        # Filled by predict(): maps MapType -> predicted grid.
        self.predicted_grids = {}

    def _process_sample(
        self,
        input_name: str,
        output_shape: Tuple[int],
        box_size: int,
        array: np.ndarray[np.float32],
        translation: Tuple[int, int, int],
    ) -> Tuple[np.ndarray, Tuple[int, int, int]]:
        """Perform inference on a single sample of shape (1, box_size, box_size, box_size, 1) and return an array of shape
        (box_size, box_size, box_size, output_channels) and the translation (for putting back into an array).
        """
        i, j, k = translation
        input_sub = array[i : i + box_size, j : j + box_size, k : k + box_size]
        # Add the leading batch axis and trailing channel axis the model input uses.
        input_sub = input_sub[np.newaxis, ..., np.newaxis].astype(np.float32)

        return np.array(self.model.run(None, {input_name: input_sub})).reshape(
            output_shape
        ), translation

    def _run_prediction(self, work_grid: np.ndarray) -> np.ndarray:
        """Run prediction on work_grid and combine the overlapping boxes.

        Returns the per-channel mean when raw values are requested, the
        per-channel sample variance when variance is requested, and otherwise
        the index of the highest-scoring channel for every grid point.
        """
        work_grid_shape = np.array(work_grid.shape)
        slices = precompute_slices(work_grid_shape, overlap=self.configuration.overlap)
        box_size = self.configuration.box_size

        total_array = np.zeros(
            (*work_grid_shape, self.configuration.channels), dtype=np.float32
        )
        count_array = np.zeros_like(total_array, dtype=np.float32)

        # Running mean and sum of squared deviations for Welford's one pass
        # variance method.
        variance_mean = np.zeros_like(total_array, dtype=np.float32)
        variance_m2 = np.zeros_like(total_array, dtype=np.float32)

        channels = self.configuration.channels
        input_name = self.model.get_inputs()[0].name
        output_shape = (box_size, box_size, box_size, channels)
        process_sample_worker = functools.partial(
            self._process_sample,
            input_name,
            output_shape,
            box_size,
            work_grid,
        )

        miniters = 1_000 if len(slices) > 10_000 else 1
        max_workers = self.configuration.n_threads
        if max_workers == 1:
            results = list(
                tqdm(
                    map(process_sample_worker, slices),
                    total=len(slices),
                    desc="Predicting",
                    miniters=miniters,
                    disable=self.configuration.disable_progress_bar,
                )
            )
        else:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(
                    tqdm(
                        executor.map(process_sample_worker, slices),
                        total=len(slices),
                        desc="Predicting",
                        miniters=miniters,
                        disable=self.configuration.disable_progress_bar,
                    )
                )

        ones = np.ones(channels)
        for result in tqdm(
            results,
            desc="Processing results",
            disable=self.configuration.disable_progress_bar,
        ):
            predicted_sub, (i, j, k) = result
            box_slice = (
                slice(i, i + box_size),
                slice(j, j + box_size),
                slice(k, k + box_size),
                slice(None),
            )
            total_array[box_slice] += predicted_sub
            count_array[box_slice] += ones

            if self.configuration.compute_variance:
                # Welford update: the new sample is this box's prediction, not
                # the running sum accumulated in total_array.
                delta_variance = predicted_sub - variance_mean[box_slice]
                variance_mean[box_slice] += delta_variance / count_array[box_slice]
                variance_m2[box_slice] += delta_variance * (
                    predicted_sub - variance_mean[box_slice]
                )

        predicted_array = total_array / count_array
        if self.configuration.use_raw_values:
            return predicted_array.astype(np.float32)

        if self.configuration.compute_variance:
            variance_array = variance_m2 / (np.subtract(count_array, 1))
            return variance_array.astype(np.float32)

        argmax_array = np.argmax(predicted_array, axis=-1).squeeze()
        return argmax_array.astype(np.float32)

    def predict(
        self,
        density_path: Path | str,
        column_names: List[str] | List[None],
        resolution_cutoff: float | None = None,
    ):
        """Run a sails prediction on specified density file. If density file is an MTZ, supply column names and an
        optional resolution cutoff. If density file is a MAP, these will be ignored."""
        self.model = load_onnx_model(self.model_path, self.configuration.use_gpu)
        input_grid = load_density(density_path, column_names, resolution_cutoff)

        work_grid, transform = interpolate_grid(input_grid, self.configuration)
        predicted_array = self._run_prediction(work_grid)

        rounded_array = np.round(predicted_array)
        # Channel 0 is skipped; channels 1.. map onto MapType values.
        for i in range(1, self.configuration.channels):
            if self.configuration.use_raw_values or self.configuration.compute_variance:
                index_array = predicted_array[:, :, :, i].astype(np.float32)
            else:
                # Class map: 1.0 where this class won the argmax, else 0.0.
                index_array = (rounded_array == i).astype(np.float32)

            interpolated_index_array = reinterpolate_grid(
                index_array,
                transform,
                input_grid,
                self.configuration.compute_entire_unit_cell,
            )
            self.predicted_grids[MapType(i)] = interpolated_index_array

    def save_grid(self, type: MapType, output_path: Path | str):
        """Save the predicted grid to directory specified by output_path, with filename sails-{type}.map."""
        output_path = Path(output_path)
        output_path.mkdir(exist_ok=True, parents=True)
        logging.info(f"Saving grid of type {type} to {output_path}")

        # Raw output takes precedence over variance when both are enabled.
        suffix = ".map"
        suffix = ".variance.map" if self.configuration.compute_variance else suffix
        suffix = ".raw.map" if self.configuration.use_raw_values else suffix

        save_grid(
            self.predicted_grids[type],
            output_path / f"sails-{type.name}{suffix}",
        )

    def get_grid(self, type: MapType):
        """Return the predicted grid stored for ``type``; raises KeyError if predict() has not produced it."""
        return self.predicted_grids[type]
def predict_map(
    model: str,
    input: str,
    output: str,
    resolution: float = None,
    amplitude: str = "FWT",
    phase: str = "PHWT",
    overlap: int = None,
    nthreads: int = 1,
    save_map: bool = False,
) -> gemmi.FloatGrid | Tuple[gemmi.FloatGrid, gemmi.FloatGrid]:
    """Run a prediction from Python rather than the command line.

    ``model`` is the name of a ModelType member. The glycan grid is returned;
    for the multiclass model a (glycan, protein) tuple is returned instead.
    When ``save_map`` is set the same grids are also written below ``output``.
    """
    logging.info(
        f"Running prediction with model {model}, input {input}, output {output}, resolution {resolution}, amplitude {amplitude}, phase {phase}, overlap {overlap}"
    )

    model = ModelType[model]
    model_path, _ = find_model(model)
    is_multiclass = model == ModelType.multiclass

    settings = Configuration(
        use_gpu=False,
        disable_progress_bar=False,
        compute_entire_unit_cell=False,
        n_threads=nthreads,
        **vars(get_model_config(model_path, overlap)),
    )
    predictor = Sails(model_path, configuration=settings)
    predictor.predict(input, [amplitude, phase], resolution_cutoff=resolution)

    if save_map:
        predictor.save_grid(MapType.glycan, output)
        if is_multiclass:
            predictor.save_grid(MapType.protein, output)

    glycan_grid = predictor.get_grid(MapType.glycan)
    if not is_multiclass:
        return glycan_grid
    return glycan_grid, predictor.get_grid(MapType.protein)
map.write_ccp4_map(str(path)) diff --git a/package/src/sails/prediction/util.py b/package/src/sails/prediction/util.py new file mode 100644 index 0000000..174d14a --- /dev/null +++ b/package/src/sails/prediction/util.py @@ -0,0 +1,41 @@ +import logging +from pathlib import Path + +import gemmi +from typing import Tuple +import sys + + +def find_map_coefficients(mtz: gemmi.Mtz) -> Tuple[str, str]: + """Return the (amplitude, phase) column labels to use from an MTZ file: FWT/PHWT when both are present, otherwise the first F-type and first P-type labels; logs a critical error and exits with status 1 if either column type is missing.""" + Fs = mtz.columns_with_type("F") + Ps = mtz.columns_with_type("P") + Fs = [F.label for F in Fs] + Ps = [P.label for P in Ps] + + if not Fs or not Ps: + logging.critical("No F and P columns found in MTZ file.") + sys.exit(1) + + if "FWT" in Fs and "PHWT" in Ps: + logging.warning("FWT and PHWT found, using them.") + return "FWT", "PHWT" + + F, P = Fs[0], Ps[0] + if len(Fs) != 1 or len(Ps) != 1: + logging.warning(f"Multiple F and P columns found. Using first set. {F=}, {P=}") + return F, P + + +def check_density_path(density_path): + """Return density_path as a Path after checking that it exists and that every one of its suffixes is in the allowed list; logs a critical error and exits with status 1 otherwise.""" + allowed_extensions = [".mtz", ".map", ".ccp4", ".mrc", ".gz"] + density_path = Path(density_path) + if not density_path.exists(): + logging.critical(f"Density file {density_path} does not exist.") + sys.exit(1) + + if any(suffix not in allowed_extensions for suffix in density_path.suffixes): + logging.critical(f"Density file must be one of {allowed_extensions}") + sys.exit(1) + return density_path diff --git a/package/src/sails/snfg.py b/package/src/sails/snfg.py index d4b5ef0..bd35586 100644 --- a/package/src/sails/snfg.py +++ b/package/src/sails/snfg.py @@ -110,10 +110,10 @@ def create_single_snfg(args): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("-model", type=str, required=True) - parser.add_argument("-snfgout", type=str, required=True) - parser.add_argument("-chain", type=str, required=False) - parser.add_argument("-seqid", type=int, required=False) + parser.add_argument("--model", type=str,
required=True) + parser.add_argument("--snfgout", type=str, required=True) + parser.add_argument("--chain", type=str, required=False) + parser.add_argument("--seqid", type=int, required=False) parser.add_argument("--all", action=argparse.BooleanOptionalAction, required=False) parser.add_argument( "--overwrite", action=argparse.BooleanOptionalAction, required=False diff --git a/package/src/sails/validate.py b/package/src/sails/validate.py new file mode 100644 index 0000000..15e90c2 --- /dev/null +++ b/package/src/sails/validate.py @@ -0,0 +1,129 @@ +import argparse +import json +import time + +from .__version__ import __version__ +import importlib +from sails import validate, validate_site, interface + +from .glycosylate import get_column_labels + + +def parse_args(): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="mode", required=True) + + parser.add_argument("--version", action="version", version=__version__) + + parent = argparse.ArgumentParser(add_help=False) + group = parent.add_argument_group("Required arguments for all modes") + group.add_argument("-v", action=argparse.BooleanOptionalAction, default=False) + group.add_argument("--modelin", type=str, required=True) + group.add_argument("--modelout", type=str, default="sails-validate.cif") + group.add_argument("--logout", type=str, default="sails-validate.log") + group.add_argument( + "--threshold", type=float, default=0.8, help="RSCC Threshold to use for removal" + ) + group.add_argument("--remove", action=argparse.BooleanOptionalAction, default=False) + group.add_argument("--print", action=argparse.BooleanOptionalAction, default=False) + + formatter = argparse.ArgumentDefaultsHelpFormatter + xray_parser = subparsers.add_parser( + "xray", parents=[parent], formatter_class=formatter + ) + xray_parser_group = xray_parser.add_argument_group( + "Required arguments in X-ray mode" + ) + xray_parser_group.add_argument("--mtzin", type=str, required=True) + 
xray_parser_group.add_argument( + "--colin-fo", type=str, required=False, default="FP,SIGFP" + ) + xray_parser_group.add_argument( + "--colin-fwt", type=str, required=False, default="FWT,PHWT" + ) + xray_parser_group.add_argument("--chain", type=str, required=False) + xray_parser_group.add_argument("--seqid", type=str, required=False) + + em_parser = subparsers.add_parser("em", parents=[parent], formatter_class=formatter) + em_parser_group = em_parser.add_argument_group("Required arguments in EM mode") + em_parser_group.add_argument("--mapin", type=str, required=True) + em_parser_group.add_argument("--resolution", type=float, required=True) + em_parser_group.add_argument( + "--score", choices=["q", "rscc"], required=False, default="q" + ) + + return parser.parse_args() + + +def xray(args): + sails_structure = interface.get_sails_structure(args.modelin) + resource = importlib.resources.files("sails").joinpath("data") + + labels = get_column_labels(args.colin_fo, args.colin_fwt) + sails_mtz = interface.get_sails_mtz(args.mtzin, *labels) + + if args.chain and args.seqid: + result = validate_site( + sails_structure, + sails_mtz, + args.chain, + args.seqid, + args.remove, + args.threshold, + str(resource), + ) + else: + result = validate( + sails_structure, sails_mtz, args.remove, args.threshold, str(resource) + ) + + structure = interface.extract_sails_structure(result.structure) + structure.make_mmcif_block().write_file(args.modelout) + log = json.loads(result.log) + + if args.print: + print(json.dumps(log, indent=4)) + + with open(args.logout, "w") as f: + json.dump(log, f, indent=4) + + +def em(args): + sails_structure = interface.get_sails_structure(args.modelin) + sails_grid = interface.get_sails_map(args.mapin) + resource = importlib.resources.files("sails").joinpath("data") + + result = validate( + sails_structure, + sails_grid, + args.resolution, + args.remove, + args.threshold, + args.score == "q", + str(resource), + ) + + structure = 
interface.extract_sails_structure(result.structure) + structure.make_mmcif_block().write_file(args.modelout) + log = json.loads(result.log) + + if args.print: + print(json.dumps(log, indent=4)) + + with open(args.logout, "w") as f: + json.dump(log, f, indent=4) + + +def run(): + t0 = time.time() + args = parse_args() + + if args.mode == "xray": + xray(args) + elif args.mode == "em": + em(args) + else: + raise RuntimeError("Unknown mode") + + t1 = time.time() + print(f"Sails Validate - Time Taken = {(t1 - t0)} seconds") diff --git a/package/tests/test_glycosylation.py b/package/tests/test_glycosylation.py index e139782..13000a0 100644 --- a/package/tests/test_glycosylation.py +++ b/package/tests/test_glycosylation.py @@ -18,7 +18,7 @@ def cglycan(data_base_path): s = gemmi.read_structure(str(s_path)) m = gemmi.read_mtz_file(str(m_path)) - return s, m, 1, "FP", "SIGFP", "", "", sails.Type.c_glycosylate + return s, m, "", "", "", 1, "FP", "SIGFP", "", "", sails.Type.c_glycosylate def test_xtal_cglycosylation(cglycan): @@ -43,26 +43,22 @@ def test_xtal_cglycosylation(cglycan): assert "entries" in cycle entries = cycle["entries"] - expected_key = "D-AMAN-1" + expected_key = "D-MAN-1" assert expected_key in entries assert len(entries.keys()) == 1 sugar = entries[expected_key] rscc_key = "rscc" rsr_key = "rsr" - dds_key = "dds" assert rscc_key in sugar assert rsr_key in sugar - assert dds_key in sugar rscc_score = sugar[rscc_key] rsr_score = sugar[rsr_key] - dds_score = sugar[dds_key] assert rscc_score > 0.7 assert rsr_score > 0.9 - assert dds_score < 0.75 # test snfg output assert 1 in snfgs