Skip to content

Commit 3e5e023

Browse files
committed
fix some error and add table test
1 parent 8be85db commit 3e5e023

4 files changed

Lines changed: 226 additions & 227 deletions

File tree

be/src/vec/functions/array/function_array_cross_product.cpp

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,191 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
#include "vec/functions/array/function_array_cross_product.h"
18+
#include <gen_cpp/Types_types.h>
1919

20+
#include "common/exception.h"
21+
#include "common/status.h"
22+
#include "runtime/primitive_type.h"
23+
#include "vec/columns/column.h"
24+
#include "vec/columns/column_array.h"
25+
#include "vec/columns/column_nullable.h"
26+
#include "vec/common/assert_cast.h"
27+
#include "vec/core/types.h"
28+
#include "vec/data_types/data_type.h"
29+
#include "vec/data_types/data_type_array.h"
30+
#include "vec/data_types/data_type_nullable.h"
31+
#include "vec/data_types/data_type_number.h"
32+
#include "vec/functions/array/function_array_utils.h"
33+
#include "vec/functions/function.h"
2034
#include "vec/functions/simple_function_factory.h"
35+
#include "vec/utils/util.hpp"
2136

2237
namespace doris::vectorized {
2338

39+
class FunctionArrayCrossProduct : public IFunction {
40+
public:
41+
using DataType = PrimitiveTypeTraits<TYPE_FLOAT>::DataType;
42+
using ColumnType = PrimitiveTypeTraits<TYPE_FLOAT>::ColumnType;
43+
44+
static constexpr auto name = "cross_product";
45+
String get_name() const override { return name; }
46+
static FunctionPtr create() { return std::make_shared<FunctionArrayCrossProduct>(); }
47+
size_t get_number_of_arguments() const override { return 2; }
48+
49+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
50+
if (arguments.size() != 2) {
51+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
52+
"Invalid number of arguments for function {}", get_name());
53+
}
54+
55+
if (arguments[0]->get_primitive_type() != TYPE_ARRAY ||
56+
arguments[1]->get_primitive_type() != TYPE_ARRAY) {
57+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
58+
"Arguments for function {} must be arrays", get_name());
59+
}
60+
61+
// return ARRAY<FLOAT>
62+
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat32>());
63+
}
64+
65+
// strict semantics: do not allow NULL
66+
bool use_default_implementation_for_nulls() const override { return false; }
67+
68+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
69+
uint32_t result, size_t input_rows_count) const override {
70+
const auto& arg1 = block.get_by_position(arguments[0]);
71+
const auto& arg2 = block.get_by_position(arguments[1]);
72+
73+
auto col1 = arg1.column->convert_to_full_column_if_const();
74+
auto col2 = arg2.column->convert_to_full_column_if_const();
75+
76+
if (col1->size() != col2->size()) {
77+
return Status::FatalError(
78+
fmt::format("function {} have different input array sizes: {} rows and {} rows",
79+
get_name(), col1->size(), col2->size()));
80+
}
81+
82+
const ColumnArray* arr1 = nullptr;
83+
const ColumnArray* arr2 = nullptr;
84+
85+
if (col1->is_nullable()) {
86+
if (col1->has_null()) {
87+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
88+
"First argument for function {} cannot be null",
89+
get_name());
90+
}
91+
auto nullable1 = assert_cast<const ColumnNullable*>(col1.get());
92+
arr1 = assert_cast<const ColumnArray*>(nullable1->get_nested_column_ptr().get());
93+
} else {
94+
arr1 = assert_cast<const ColumnArray*>(col1.get());
95+
}
96+
97+
if (col2->is_nullable()) {
98+
if (col2->has_null()) {
99+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
100+
"Second argument for function {} cannot be null",
101+
get_name());
102+
}
103+
auto nullable2 = assert_cast<const ColumnNullable*>(col2.get());
104+
arr2 = assert_cast<const ColumnArray*>(nullable2->get_nested_column_ptr().get());
105+
} else {
106+
arr2 = assert_cast<const ColumnArray*>(col2.get());
107+
}
108+
109+
const ColumnFloat32* float1 = nullptr;
110+
const ColumnFloat32* float2 = nullptr;
111+
112+
if (arr1->get_data_ptr()->is_nullable()) {
113+
if (arr1->get_data_ptr()->has_null()) {
114+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
115+
"First argument for function {} cannot have null elements",
116+
get_name());
117+
}
118+
auto nullable1 = assert_cast<const ColumnNullable*>(arr1->get_data_ptr().get());
119+
float1 = assert_cast<const ColumnFloat32*>(nullable1->get_nested_column_ptr().get());
120+
} else {
121+
float1 = assert_cast<const ColumnFloat32*>(arr1->get_data_ptr().get());
122+
}
123+
124+
if (arr2->get_data_ptr()->is_nullable()) {
125+
if (arr2->get_data_ptr()->has_null()) {
126+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
127+
"Second argument for function {} cannot have null elements",
128+
get_name());
129+
}
130+
auto nullable2 = assert_cast<const ColumnNullable*>(arr2->get_data_ptr().get());
131+
float2 = assert_cast<const ColumnFloat32*>(nullable2->get_nested_column_ptr().get());
132+
} else {
133+
float2 = assert_cast<const ColumnFloat32*>(arr2->get_data_ptr().get());
134+
}
135+
136+
const auto* offset1 =
137+
assert_cast<const ColumnArray::ColumnOffsets*>(arr1->get_offsets_ptr().get());
138+
const auto* offset2 =
139+
assert_cast<const ColumnArray::ColumnOffsets*>(arr2->get_offsets_ptr().get());
140+
141+
// prepare result data
142+
auto nested_res = ColumnFloat32::create();
143+
auto& nested_data = nested_res->get_data();
144+
nested_data.resize(3 * input_rows_count);
145+
auto offsets_res = ColumnArray::ColumnOffsets::create();
146+
auto& offsets_data = offsets_res->get_data();
147+
offsets_data.resize(input_rows_count);
148+
size_t current_offset = 0;
149+
150+
size_t prev_offset1 = 0;
151+
size_t prev_offset2 = 0;
152+
for (ssize_t row = 0; row < input_rows_count; ++row) {
153+
ssize_t size1 = offset1->get_data()[row] - prev_offset1;
154+
ssize_t size2 = offset2->get_data()[row] - prev_offset2;
155+
156+
if (size1 != size2) {
157+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
158+
"function {} have different input element sizes of array: {} and {}",
159+
get_name(), size1, size2);
160+
}
161+
162+
if (size1 != 3 || size2 != 3) {
163+
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
164+
"function {} requires arrays of size 3", get_name());
165+
}
166+
167+
ssize_t base1 = prev_offset1;
168+
ssize_t base2 = prev_offset2;
169+
170+
float a1 = float1->get_data()[base1];
171+
float a2 = float1->get_data()[base1 + 1];
172+
float a3 = float1->get_data()[base1 + 2];
173+
174+
float b1 = float2->get_data()[base2];
175+
float b2 = float2->get_data()[base2 + 1];
176+
float b3 = float2->get_data()[base2 + 2];
177+
178+
float c1 = a2 * b3 - a3 * b2;
179+
float c2 = a3 * b1 - a1 * b3;
180+
float c3 = a1 * b2 - a2 * b1;
181+
182+
nested_data[row * 3] = c1;
183+
nested_data[row * 3 + 1] = c2;
184+
nested_data[row * 3 + 2] = c3;
185+
186+
current_offset += 3;
187+
offsets_data[row] = current_offset;
188+
189+
prev_offset1 = offset1->get_data()[row];
190+
prev_offset2 = offset2->get_data()[row];
191+
}
192+
193+
auto result_col = ColumnArray::create(
194+
ColumnNullable::create(std::move(nested_res),
195+
ColumnUInt8::create(nested_res->size(), 0)),
196+
std::move(offsets_res));
197+
198+
block.replace_by_position(result, std::move(result_col));
199+
return Status::OK();
200+
}
201+
};
202+
24203
/// Registers the `cross_product` array function (FunctionArrayCrossProduct)
/// with the supplied function factory.
void register_function_array_cross_product(SimpleFunctionFactory& factory) {
    factory.register_function<FunctionArrayCrossProduct>();
}

be/src/vec/functions/array/function_array_cross_product.h

Lines changed: 0 additions & 204 deletions
This file was deleted.

0 commit comments

Comments
 (0)