-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcast.cpp
More file actions
84 lines (73 loc) · 3.19 KB
/
cast.cpp
File metadata and controls
84 lines (73 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include "common.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/types/string_type.hpp"
#include "duckdb/common/types/vector.hpp"
#include "duckdb/function/cast/default_casts.hpp"
#include "mol_formats.hpp"
#include "types.hpp"
#include "umbra_mol.hpp"
#include <GraphMol/Descriptors/MolDescriptors.h>
#include <GraphMol/FileParsers/FileParsers.h>
#include <GraphMol/GraphMol.h>
#include <GraphMol/MolPickler.h>
#include <GraphMol/SmilesParse/SmartsWrite.h>
#include <GraphMol/SmilesParse/SmilesParse.h>
#include <GraphMol/SmilesParse/SmilesWrite.h>
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
namespace duckdb_rdkit {
// Converts a vector of VARCHAR values (interpreted as SMILES) into the
// extension's Mol representation.
//
// This enables the user to insert into a Mol column by just writing the
// SMILES: DuckDB will try to convert the string to an RDKit mol, which is
// consistent with the RDKit Postgres cartridge behavior.
//
// For each non-NULL input row the SMILES is parsed with
// rdkit_mol_from_smiles() and serialized with get_umbra_mol_string(); the
// serialized bytes are stored in `result`. If anything in that pipeline
// throws, a warning is written to stdout and the row is marked NULL in the
// result instead of aborting the whole cast.
//
// source: input vector of string_t holding SMILES text.
// result: output vector receiving the serialized molecules.
// count:  number of rows to process.
void VarcharToMol(Vector &source, Vector &result, idx_t count) {
  UnaryExecutor::ExecuteWithNulls<string_t, string_t>(
      source, result, count,
      [&](string_t smiles, ValidityMask &mask, idx_t idx) {
        try {
          // This varchar is just a regular string, not an umbra mol;
          // try to see if it is a SMILES.
          auto mol = rdkit_mol_from_smiles(smiles.GetString());
          // NOTE(review): guards the dereference below in case the helper
          // reports failure with a null pointer rather than an exception;
          // the helper's contract is not visible from this file.
          if (mol) {
            auto umbra_mol = get_umbra_mol_string(*mol);
            return StringVector::AddStringOrBlob(result, umbra_mol);
          }
        } catch (...) {
          // Deliberately fall through to the shared failure path below.
        }
        // string_t data is length-delimited and not guaranteed to be
        // NUL-terminated, so print the bounded std::string copy rather than
        // streaming the raw GetData() pointer.
        std::cout << "WARNING: could not create molecule from SMILES\n"
                  << smiles.GetString() << std::endl;
        mask.SetInvalid(idx);
        return string_t();
      });
}
bool VarcharToMolCast(Vector &source, Vector &result, idx_t count,
CastParameters ¶meters) {
VarcharToMol(source, result, count);
return true;
}
// Converts a vector of serialized Mol values into their SMILES text.
//
// source: input vector of string_t values in the umbra_mol_t layout.
// result: output vector receiving the SMILES strings.
// count:  number of rows to process.
void MolToVarchar(Vector &source, Vector &result, idx_t count) {
  // Per-row conversion. The incoming string_t originates from DuckDB
  // internals; this cast is only triggered for values the extension
  // recognizes as UmbraMol BLOBs, so the bytes are expected to follow the
  // umbra_mol_t format.
  auto to_smiles = [&result](string_t packed_blob) {
    auto packed_mol = umbra_mol_t(packed_blob);
    auto binary_mol = packed_mol.GetBinaryMol();
    auto parsed_mol = rdkit_binary_mol_to_mol(binary_mol);
    auto smiles_text = rdkit_mol_to_smiles(*parsed_mol);
    return StringVector::AddString(result, smiles_text);
  };

  UnaryExecutor::Execute<string_t, string_t>(source, result, count,
                                             to_smiles);
}
bool MolToVarcharCast(Vector &source, Vector &result, idx_t count,
CastParameters ¶meters) {
MolToVarchar(source, result, count);
return true;
}
// Registers the two casts between VARCHAR and the extension's Mol type
// with the given loader: VARCHAR -> Mol, then Mol -> VARCHAR.
void RegisterCasts(ExtensionLoader &loader) {
  // Value passed as the final argument of both registrations
  // (presumably DuckDB's implicit cast cost — confirm against the API).
  const int cast_cost = 1;

  loader.RegisterCastFunction(LogicalType::VARCHAR, ::duckdb_rdkit::Mol(),
                              BoundCastInfo(VarcharToMolCast), cast_cost);
  loader.RegisterCastFunction(::duckdb_rdkit::Mol(), LogicalType::VARCHAR,
                              BoundCastInfo(MolToVarcharCast), cast_cost);
}
} // namespace duckdb_rdkit