@@ -895,3 +895,153 @@ func ConvertAliasToTable(alias string, tables []*ast.TableSource) (*ast.TableNam
895895 }
896896 return nil , errors .New ("can not find table" )
897897}
898+
899+ // TableColumnMap 表示按表分组的列名集合
900+ type TableColumnMap map [string ]map [string ]struct {}
901+
902+ // ExtractColumnsFromSelectStmt 从 SELECT 语句中提取列,并按表分组
903+ // 参数:
904+ // - selectStmt: SELECT 语句节点
905+ // - aliasMap: 表别名到实际表名的映射
906+ // - allTables: 所有涉及的表名列表(用于处理无表前缀的列)
907+ //
908+ // 返回:按表名分组的列名集合
909+ func ExtractColumnsFromSelectStmt (selectStmt * ast.SelectStmt , aliasMap map [string ]string , allTables []string ) TableColumnMap {
910+ tableColumns := make (TableColumnMap )
911+
912+ // 收集 SELECT 列表中的所有列别名
913+ selectAliases := make (map [string ]struct {})
914+ if selectStmt .Fields != nil {
915+ for _ , field := range selectStmt .Fields .Fields {
916+ if field .AsName .L != "" {
917+ selectAliases [field .AsName .L ] = struct {}{}
918+ }
919+ }
920+ }
921+
922+ // 辅助函数:从表达式中提取列并按表分组
923+ extractColumnsFromExpr := func (expr ast.Node , skipAliases bool ) {
924+ if expr == nil {
925+ return
926+ }
927+ columnVisitor := & ColumnNameVisitor {}
928+ expr .Accept (columnVisitor )
929+
930+ for _ , colExpr := range columnVisitor .ColumnNameList {
931+ if colExpr .Name == nil {
932+ continue
933+ }
934+
935+ // 如果需要跳过别名且当前列名是一个别名,则跳过
936+ if skipAliases {
937+ if _ , isAlias := selectAliases [colExpr .Name .Name .L ]; isAlias && colExpr .Name .Table .L == "" {
938+ continue
939+ }
940+ }
941+
942+ var targetTableName string
943+
944+ // 如果列有表前缀(可能是别名或实际表名)
945+ if colExpr .Name .Table .L != "" {
946+ // 先尝试从别名映射中查找
947+ if actualTable , exists := aliasMap [colExpr .Name .Table .L ]; exists {
948+ targetTableName = actualTable
949+ } else {
950+ // 如果不是别名,就当作实际表名
951+ targetTableName = colExpr .Name .Table .L
952+ }
953+ }
954+
955+ if targetTableName != "" {
956+ if tableColumns [targetTableName ] == nil {
957+ tableColumns [targetTableName ] = make (map [string ]struct {})
958+ }
959+ tableColumns [targetTableName ][colExpr .Name .Name .L ] = struct {}{}
960+ } else {
961+ // 没有表前缀的列,可能属于任何表
962+ // 在多表查询中,尝试将该列添加到所有表
963+ for _ , tableName := range allTables {
964+ if tableColumns [tableName ] == nil {
965+ tableColumns [tableName ] = make (map [string ]struct {})
966+ }
967+ tableColumns [tableName ][colExpr .Name .Name .L ] = struct {}{}
968+ }
969+ }
970+ }
971+ }
972+
973+ // 从 SELECT Fields 提取列(包括聚合函数内的列)
974+ if selectStmt .Fields != nil {
975+ for _ , field := range selectStmt .Fields .Fields {
976+ extractColumnsFromExpr (field .Expr , false )
977+ }
978+ }
979+
980+ // 从 WHERE 条件提取列
981+ if selectStmt .Where != nil {
982+ extractColumnsFromExpr (selectStmt .Where , false )
983+ }
984+
985+ // 从 GROUP BY 提取列(需要跳过别名引用)
986+ if selectStmt .GroupBy != nil {
987+ for _ , item := range selectStmt .GroupBy .Items {
988+ extractColumnsFromExpr (item .Expr , true )
989+ }
990+ }
991+
992+ // 从 HAVING 提取列
993+ if selectStmt .Having != nil {
994+ extractColumnsFromExpr (selectStmt .Having .Expr , false )
995+ }
996+
997+ // 注意:不从 ORDER BY 提取,因为可能包含别名引用
998+
999+ return tableColumns
1000+ }
1001+
1002+ // a helper function to get the table alias info from join node
1003+ func GetTableAliasInfoFromJoin (join * ast.Join ) []* TableAliasInfo {
1004+ tableAlias := make ([]* TableAliasInfo , 0 )
1005+ tableSources := GetTableSourcesFromJoin (join )
1006+ for _ , tableSource := range tableSources {
1007+ if tableName , ok := tableSource .Source .(* ast.TableName ); ok {
1008+ tableAlias = append (tableAlias , & TableAliasInfo {
1009+ TableAliasName : tableSource .AsName .String (),
1010+ TableName : tableName .Name .O ,
1011+ SchemaName : tableName .Schema .O ,
1012+ })
1013+ }
1014+ }
1015+ return tableAlias
1016+ }
1017+
1018+ type TableAliasInfo struct {
1019+ TableName string
1020+ SchemaName string
1021+ TableAliasName string
1022+ }
1023+
1024+ // a helper function to get the table source from join node
1025+ func GetTableSourcesFromJoin (join * ast.Join ) []* ast.TableSource {
1026+ sources := []* ast.TableSource {}
1027+ if join == nil {
1028+ return sources
1029+ }
1030+ if n := join .Left ; n != nil {
1031+ switch t := n .(type ) {
1032+ case * ast.TableSource :
1033+ sources = append (sources , t )
1034+ case * ast.Join :
1035+ sources = append (sources , GetTableSourcesFromJoin (t )... )
1036+ }
1037+ }
1038+ if n := join .Right ; n != nil {
1039+ switch t := n .(type ) {
1040+ case * ast.TableSource :
1041+ sources = append (sources , t )
1042+ case * ast.Join :
1043+ sources = append (sources , GetTableSourcesFromJoin (t )... )
1044+ }
1045+ }
1046+ return sources
1047+ }
0 commit comments