|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -#include "vec/functions/array/function_array_cross_product.h" |
| 18 | +#include <gen_cpp/Types_types.h> |
19 | 19 |
|
| 20 | +#include "common/exception.h" |
| 21 | +#include "common/status.h" |
| 22 | +#include "runtime/primitive_type.h" |
| 23 | +#include "vec/columns/column.h" |
| 24 | +#include "vec/columns/column_array.h" |
| 25 | +#include "vec/columns/column_nullable.h" |
| 26 | +#include "vec/common/assert_cast.h" |
| 27 | +#include "vec/core/types.h" |
| 28 | +#include "vec/data_types/data_type.h" |
| 29 | +#include "vec/data_types/data_type_array.h" |
| 30 | +#include "vec/data_types/data_type_nullable.h" |
| 31 | +#include "vec/data_types/data_type_number.h" |
| 32 | +#include "vec/functions/array/function_array_utils.h" |
| 33 | +#include "vec/functions/function.h" |
20 | 34 | #include "vec/functions/simple_function_factory.h" |
| 35 | +#include "vec/utils/util.hpp" |
21 | 36 |
|
22 | 37 | namespace doris::vectorized { |
23 | 38 |
|
| 39 | +class FunctionArrayCrossProduct : public IFunction { |
| 40 | +public: |
| 41 | + using DataType = PrimitiveTypeTraits<TYPE_FLOAT>::DataType; |
| 42 | + using ColumnType = PrimitiveTypeTraits<TYPE_FLOAT>::ColumnType; |
| 43 | + |
| 44 | + static constexpr auto name = "cross_product"; |
| 45 | + String get_name() const override { return name; } |
| 46 | + static FunctionPtr create() { return std::make_shared<FunctionArrayCrossProduct>(); } |
| 47 | + size_t get_number_of_arguments() const override { return 2; } |
| 48 | + |
| 49 | + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { |
| 50 | + if (arguments.size() != 2) { |
| 51 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 52 | + "Invalid number of arguments for function {}", get_name()); |
| 53 | + } |
| 54 | + |
| 55 | + if (arguments[0]->get_primitive_type() != TYPE_ARRAY || |
| 56 | + arguments[1]->get_primitive_type() != TYPE_ARRAY) { |
| 57 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 58 | + "Arguments for function {} must be arrays", get_name()); |
| 59 | + } |
| 60 | + |
| 61 | + // return ARRAY<FLOAT> |
| 62 | + return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat32>()); |
| 63 | + } |
| 64 | + |
| 65 | + // strict semantics: do not allow NULL |
| 66 | + bool use_default_implementation_for_nulls() const override { return false; } |
| 67 | + |
| 68 | + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, |
| 69 | + uint32_t result, size_t input_rows_count) const override { |
| 70 | + const auto& arg1 = block.get_by_position(arguments[0]); |
| 71 | + const auto& arg2 = block.get_by_position(arguments[1]); |
| 72 | + |
| 73 | + auto col1 = arg1.column->convert_to_full_column_if_const(); |
| 74 | + auto col2 = arg2.column->convert_to_full_column_if_const(); |
| 75 | + |
| 76 | + if (col1->size() != col2->size()) { |
| 77 | + return Status::FatalError( |
| 78 | + fmt::format("function {} have different input array sizes: {} rows and {} rows", |
| 79 | + get_name(), col1->size(), col2->size())); |
| 80 | + } |
| 81 | + |
| 82 | + const ColumnArray* arr1 = nullptr; |
| 83 | + const ColumnArray* arr2 = nullptr; |
| 84 | + |
| 85 | + if (col1->is_nullable()) { |
| 86 | + if (col1->has_null()) { |
| 87 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 88 | + "First argument for function {} cannot be null", |
| 89 | + get_name()); |
| 90 | + } |
| 91 | + auto nullable1 = assert_cast<const ColumnNullable*>(col1.get()); |
| 92 | + arr1 = assert_cast<const ColumnArray*>(nullable1->get_nested_column_ptr().get()); |
| 93 | + } else { |
| 94 | + arr1 = assert_cast<const ColumnArray*>(col1.get()); |
| 95 | + } |
| 96 | + |
| 97 | + if (col2->is_nullable()) { |
| 98 | + if (col2->has_null()) { |
| 99 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 100 | + "Second argument for function {} cannot be null", |
| 101 | + get_name()); |
| 102 | + } |
| 103 | + auto nullable2 = assert_cast<const ColumnNullable*>(col2.get()); |
| 104 | + arr2 = assert_cast<const ColumnArray*>(nullable2->get_nested_column_ptr().get()); |
| 105 | + } else { |
| 106 | + arr2 = assert_cast<const ColumnArray*>(col2.get()); |
| 107 | + } |
| 108 | + |
| 109 | + const ColumnFloat32* float1 = nullptr; |
| 110 | + const ColumnFloat32* float2 = nullptr; |
| 111 | + |
| 112 | + if (arr1->get_data_ptr()->is_nullable()) { |
| 113 | + if (arr1->get_data_ptr()->has_null()) { |
| 114 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 115 | + "First argument for function {} cannot have null elements", |
| 116 | + get_name()); |
| 117 | + } |
| 118 | + auto nullable1 = assert_cast<const ColumnNullable*>(arr1->get_data_ptr().get()); |
| 119 | + float1 = assert_cast<const ColumnFloat32*>(nullable1->get_nested_column_ptr().get()); |
| 120 | + } else { |
| 121 | + float1 = assert_cast<const ColumnFloat32*>(arr1->get_data_ptr().get()); |
| 122 | + } |
| 123 | + |
| 124 | + if (arr2->get_data_ptr()->is_nullable()) { |
| 125 | + if (arr2->get_data_ptr()->has_null()) { |
| 126 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 127 | + "Second argument for function {} cannot have null elements", |
| 128 | + get_name()); |
| 129 | + } |
| 130 | + auto nullable2 = assert_cast<const ColumnNullable*>(arr2->get_data_ptr().get()); |
| 131 | + float2 = assert_cast<const ColumnFloat32*>(nullable2->get_nested_column_ptr().get()); |
| 132 | + } else { |
| 133 | + float2 = assert_cast<const ColumnFloat32*>(arr2->get_data_ptr().get()); |
| 134 | + } |
| 135 | + |
| 136 | + const auto* offset1 = |
| 137 | + assert_cast<const ColumnArray::ColumnOffsets*>(arr1->get_offsets_ptr().get()); |
| 138 | + const auto* offset2 = |
| 139 | + assert_cast<const ColumnArray::ColumnOffsets*>(arr2->get_offsets_ptr().get()); |
| 140 | + |
| 141 | + // prepare result data |
| 142 | + auto nested_res = ColumnFloat32::create(); |
| 143 | + auto& nested_data = nested_res->get_data(); |
| 144 | + nested_data.resize(3 * input_rows_count); |
| 145 | + auto offsets_res = ColumnArray::ColumnOffsets::create(); |
| 146 | + auto& offsets_data = offsets_res->get_data(); |
| 147 | + offsets_data.resize(input_rows_count); |
| 148 | + size_t current_offset = 0; |
| 149 | + |
| 150 | + size_t prev_offset1 = 0; |
| 151 | + size_t prev_offset2 = 0; |
| 152 | + for (ssize_t row = 0; row < input_rows_count; ++row) { |
| 153 | + ssize_t size1 = offset1->get_data()[row] - prev_offset1; |
| 154 | + ssize_t size2 = offset2->get_data()[row] - prev_offset2; |
| 155 | + |
| 156 | + if (size1 != size2) { |
| 157 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 158 | + "function {} have different input element sizes of array: {} and {}", |
| 159 | + get_name(), size1, size2); |
| 160 | + } |
| 161 | + |
| 162 | + if (size1 != 3 || size2 != 3) { |
| 163 | + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, |
| 164 | + "function {} requires arrays of size 3", get_name()); |
| 165 | + } |
| 166 | + |
| 167 | + ssize_t base1 = prev_offset1; |
| 168 | + ssize_t base2 = prev_offset2; |
| 169 | + |
| 170 | + float a1 = float1->get_data()[base1]; |
| 171 | + float a2 = float1->get_data()[base1 + 1]; |
| 172 | + float a3 = float1->get_data()[base1 + 2]; |
| 173 | + |
| 174 | + float b1 = float2->get_data()[base2]; |
| 175 | + float b2 = float2->get_data()[base2 + 1]; |
| 176 | + float b3 = float2->get_data()[base2 + 2]; |
| 177 | + |
| 178 | + float c1 = a2 * b3 - a3 * b2; |
| 179 | + float c2 = a3 * b1 - a1 * b3; |
| 180 | + float c3 = a1 * b2 - a2 * b1; |
| 181 | + |
| 182 | + nested_data[row * 3] = c1; |
| 183 | + nested_data[row * 3 + 1] = c2; |
| 184 | + nested_data[row * 3 + 2] = c3; |
| 185 | + |
| 186 | + current_offset += 3; |
| 187 | + offsets_data[row] = current_offset; |
| 188 | + |
| 189 | + prev_offset1 = offset1->get_data()[row]; |
| 190 | + prev_offset2 = offset2->get_data()[row]; |
| 191 | + } |
| 192 | + |
| 193 | + auto result_col = ColumnArray::create( |
| 194 | + ColumnNullable::create(std::move(nested_res), |
| 195 | + ColumnUInt8::create(nested_res->size(), 0)), |
| 196 | + std::move(offsets_res)); |
| 197 | + |
| 198 | + block.replace_by_position(result, std::move(result_col)); |
| 199 | + return Status::OK(); |
| 200 | + } |
| 201 | +}; |
| 202 | + |
24 | 203 | void register_function_array_cross_product(SimpleFunctionFactory& factory) { |
25 | 204 | factory.register_function<FunctionArrayCrossProduct>(); |
26 | 205 | } |
|
0 commit comments