Skip to content

Commit a09cf74

Browse files
committed
feat: add GetSelectivityOfSQLColumns method and utility functions for extracting column selectivity from SQL statements
1 parent d129908 commit a09cf74

File tree

2 files changed

+227
-0
lines changed

2 files changed

+227
-0
lines changed

sqle/driver/mysql/mysql.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,83 @@ func (p *PluginProcessor) GetDriverMetas() (*driverV2.DriverMetas, error) {
601601
}, nil
602602
}
603603

604+
func (i *MysqlDriverImpl) GetSelectivityOfSQLColumns(ctx context.Context, sql string) (map[string]map[string]float32, error) {
605+
node, err := util.ParseOneSql(sql)
606+
if err != nil {
607+
return nil, err
608+
}
609+
610+
if _, ok := node.(*ast.SelectStmt); !ok {
611+
log.NewEntry().Errorf("get selectivity of sql columns failed, sql is not a select statement, sql: %s", sql)
612+
return nil, nil
613+
}
614+
615+
selectVisitor := &util.SelectVisitor{}
616+
node.Accept(selectVisitor)
617+
618+
result := make(map[string]map[string]float32)
619+
620+
for _, selectNode := range selectVisitor.SelectList {
621+
if selectNode.From == nil || selectNode.From.TableRefs == nil {
622+
continue
623+
}
624+
625+
// 获取表别名映射关系
626+
aliasInfo := util.GetTableAliasInfoFromJoin(selectNode.From.TableRefs)
627+
aliasMap := make(map[string]string)
628+
allTables := make([]string, 0, len(aliasInfo))
629+
630+
for _, alias := range aliasInfo {
631+
if alias.TableAliasName != "" {
632+
aliasMap[alias.TableAliasName] = alias.TableName
633+
}
634+
allTables = append(allTables, alias.TableName)
635+
}
636+
637+
// 提取列并按表分组
638+
tableColumns := util.ExtractColumnsFromSelectStmt(selectNode, aliasMap, allTables)
639+
640+
// 遍历每个表,获取其列的选择性
641+
for tableName, columnSet := range tableColumns {
642+
columns := make([]string, 0, len(columnSet))
643+
for colName := range columnSet {
644+
columns = append(columns, colName)
645+
}
646+
647+
if len(columns) == 0 {
648+
continue
649+
}
650+
651+
// 构造 TableName 对象
652+
var schemaName string
653+
for _, alias := range aliasInfo {
654+
if alias.TableName == tableName {
655+
schemaName = alias.SchemaName
656+
break
657+
}
658+
}
659+
tableNameObj := util.NewTableName(schemaName, tableName)
660+
661+
columnSelectivityMap, err := i.Ctx.GetSelectivityOfColumns(tableNameObj, columns)
662+
if err != nil {
663+
log.NewEntry().Errorf("get selectivity of columns failed, table: %s, columns: %v, error: %v", tableName, columns, err)
664+
continue
665+
}
666+
667+
if result[tableName] == nil {
668+
result[tableName] = make(map[string]float32)
669+
}
670+
for columnName, selectivity := range columnSelectivityMap {
671+
if selectivity > 0 {
672+
result[tableName][columnName] = float32(selectivity)
673+
}
674+
}
675+
}
676+
}
677+
678+
return result, nil
679+
}
680+
604681
func (p *PluginProcessor) Open(l *logrus.Entry, cfg *driverV2.Config) (driver.Plugin, error) {
605682
return NewInspect(l, cfg)
606683
}

sqle/driver/mysql/util/parser_helper.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,3 +895,153 @@ func ConvertAliasToTable(alias string, tables []*ast.TableSource) (*ast.TableNam
895895
}
896896
return nil, errors.New("can not find table")
897897
}
898+
899+
// TableColumnMap 表示按表分组的列名集合
900+
type TableColumnMap map[string]map[string]struct{}
901+
902+
// ExtractColumnsFromSelectStmt 从 SELECT 语句中提取列,并按表分组
903+
// 参数:
904+
// - selectStmt: SELECT 语句节点
905+
// - aliasMap: 表别名到实际表名的映射
906+
// - allTables: 所有涉及的表名列表(用于处理无表前缀的列)
907+
//
908+
// 返回:按表名分组的列名集合
909+
func ExtractColumnsFromSelectStmt(selectStmt *ast.SelectStmt, aliasMap map[string]string, allTables []string) TableColumnMap {
910+
tableColumns := make(TableColumnMap)
911+
912+
// 收集 SELECT 列表中的所有列别名
913+
selectAliases := make(map[string]struct{})
914+
if selectStmt.Fields != nil {
915+
for _, field := range selectStmt.Fields.Fields {
916+
if field.AsName.L != "" {
917+
selectAliases[field.AsName.L] = struct{}{}
918+
}
919+
}
920+
}
921+
922+
// 辅助函数:从表达式中提取列并按表分组
923+
extractColumnsFromExpr := func(expr ast.Node, skipAliases bool) {
924+
if expr == nil {
925+
return
926+
}
927+
columnVisitor := &ColumnNameVisitor{}
928+
expr.Accept(columnVisitor)
929+
930+
for _, colExpr := range columnVisitor.ColumnNameList {
931+
if colExpr.Name == nil {
932+
continue
933+
}
934+
935+
// 如果需要跳过别名且当前列名是一个别名,则跳过
936+
if skipAliases {
937+
if _, isAlias := selectAliases[colExpr.Name.Name.L]; isAlias && colExpr.Name.Table.L == "" {
938+
continue
939+
}
940+
}
941+
942+
var targetTableName string
943+
944+
// 如果列有表前缀(可能是别名或实际表名)
945+
if colExpr.Name.Table.L != "" {
946+
// 先尝试从别名映射中查找
947+
if actualTable, exists := aliasMap[colExpr.Name.Table.L]; exists {
948+
targetTableName = actualTable
949+
} else {
950+
// 如果不是别名,就当作实际表名
951+
targetTableName = colExpr.Name.Table.L
952+
}
953+
}
954+
955+
if targetTableName != "" {
956+
if tableColumns[targetTableName] == nil {
957+
tableColumns[targetTableName] = make(map[string]struct{})
958+
}
959+
tableColumns[targetTableName][colExpr.Name.Name.L] = struct{}{}
960+
} else {
961+
// 没有表前缀的列,可能属于任何表
962+
// 在多表查询中,尝试将该列添加到所有表
963+
for _, tableName := range allTables {
964+
if tableColumns[tableName] == nil {
965+
tableColumns[tableName] = make(map[string]struct{})
966+
}
967+
tableColumns[tableName][colExpr.Name.Name.L] = struct{}{}
968+
}
969+
}
970+
}
971+
}
972+
973+
// 从 SELECT Fields 提取列(包括聚合函数内的列)
974+
if selectStmt.Fields != nil {
975+
for _, field := range selectStmt.Fields.Fields {
976+
extractColumnsFromExpr(field.Expr, false)
977+
}
978+
}
979+
980+
// 从 WHERE 条件提取列
981+
if selectStmt.Where != nil {
982+
extractColumnsFromExpr(selectStmt.Where, false)
983+
}
984+
985+
// 从 GROUP BY 提取列(需要跳过别名引用)
986+
if selectStmt.GroupBy != nil {
987+
for _, item := range selectStmt.GroupBy.Items {
988+
extractColumnsFromExpr(item.Expr, true)
989+
}
990+
}
991+
992+
// 从 HAVING 提取列
993+
if selectStmt.Having != nil {
994+
extractColumnsFromExpr(selectStmt.Having.Expr, false)
995+
}
996+
997+
// 注意:不从 ORDER BY 提取,因为可能包含别名引用
998+
999+
return tableColumns
1000+
}
1001+
1002+
// a helper function to get the table alias info from join node
1003+
func GetTableAliasInfoFromJoin(join *ast.Join) []*TableAliasInfo {
1004+
tableAlias := make([]*TableAliasInfo, 0)
1005+
tableSources := GetTableSourcesFromJoin(join)
1006+
for _, tableSource := range tableSources {
1007+
if tableName, ok := tableSource.Source.(*ast.TableName); ok {
1008+
tableAlias = append(tableAlias, &TableAliasInfo{
1009+
TableAliasName: tableSource.AsName.String(),
1010+
TableName: tableName.Name.O,
1011+
SchemaName: tableName.Schema.O,
1012+
})
1013+
}
1014+
}
1015+
return tableAlias
1016+
}
1017+
1018+
type TableAliasInfo struct {
1019+
TableName string
1020+
SchemaName string
1021+
TableAliasName string
1022+
}
1023+
1024+
// a helper function to get the table source from join node
1025+
func GetTableSourcesFromJoin(join *ast.Join) []*ast.TableSource {
1026+
sources := []*ast.TableSource{}
1027+
if join == nil {
1028+
return sources
1029+
}
1030+
if n := join.Left; n != nil {
1031+
switch t := n.(type) {
1032+
case *ast.TableSource:
1033+
sources = append(sources, t)
1034+
case *ast.Join:
1035+
sources = append(sources, GetTableSourcesFromJoin(t)...)
1036+
}
1037+
}
1038+
if n := join.Right; n != nil {
1039+
switch t := n.(type) {
1040+
case *ast.TableSource:
1041+
sources = append(sources, t)
1042+
case *ast.Join:
1043+
sources = append(sources, GetTableSourcesFromJoin(t)...)
1044+
}
1045+
}
1046+
return sources
1047+
}

0 commit comments

Comments
 (0)