From a11986ac0cbd724689975a82740402bc159e8f00 Mon Sep 17 00:00:00 2001 From: ditatompel Date: Thu, 30 May 2024 19:19:03 +0700 Subject: [PATCH] Separating QueryNode to it's own sql builder func A quick note that is's ok to use "*" all columns on nodes query since all columns really want to be displayed to the frontend. --- internal/monero/monero.go | 171 ++++++++++----------------------- internal/monero/monero_test.go | 115 ++++++++++++++++++++++ internal/monero/prober_test.go | 31 ------ 3 files changed, 166 insertions(+), 151 deletions(-) create mode 100644 internal/monero/monero_test.go diff --git a/internal/monero/monero.go b/internal/monero/monero.go index 6ebb905..36f8c7f 100644 --- a/internal/monero/monero.go +++ b/internal/monero/monero.go @@ -69,11 +69,11 @@ type Node struct { } // Get node from database by id -func (repo *MoneroRepo) Node(id int) (Node, error) { +func (r *MoneroRepo) Node(id int) (Node, error) { var node Node - err := repo.db.Get(&node, `SELECT * FROM tbl_node WHERE id = ?`, id) + err := r.db.Get(&node, `SELECT * FROM tbl_node WHERE id = ?`, id) if err != nil && err != sql.ErrNoRows { - fmt.Println("WARN:", err) + slog.Error("WARN:", err) return node, errors.New("Can't get node information") } if err == sql.ErrNoRows { @@ -92,8 +92,8 @@ type Nodes struct { // QueryNodes represents database query parameters type QueryNodes struct { Host string - Nettype string - Protocol string + Nettype string // Can be "any", mainnet, stagenet, testnet. Default: "any" + Protocol string // Can be "any", tor, http, https. Default: "any" CC string // 2 letter country code Status int CORS int @@ -105,112 +105,85 @@ type QueryNodes struct { SortDirection string } -// Get nodes from database -func (repo *MoneroRepo) Nodes(q QueryNodes) (Nodes, error) { - queryParams := []interface{}{} - whereQueries := []string{} - where := "" +// toSQL generates SQL query from query parameters +func (q QueryNodes) toSQL() (args []interface{}, where, sortBy, sortDirection string) { + wq := []string{} if q.Host != "" { - whereQueries = append(whereQueries, "(hostname LIKE ? OR ip_addr LIKE ?)") - queryParams = append(queryParams, "%"+q.Host+"%") - queryParams = append(queryParams, "%"+q.Host+"%") + wq = append(wq, "(hostname LIKE ? OR ip_addr LIKE ?)") + args = append(args, "%"+q.Host+"%", "%"+q.Host+"%") } if q.Nettype != "any" { - if q.Nettype != "mainnet" && q.Nettype != "stagenet" && q.Nettype != "testnet" { - return Nodes{}, errors.New("Invalid nettype, must be one of 'mainnet', 'stagenet', 'testnet' or 'any'") + if q.Nettype == "mainnet" || q.Nettype == "stagenet" || q.Nettype == "testnet" { + wq = append(wq, "nettype = ?") + args = append(args, q.Nettype) } - whereQueries = append(whereQueries, "nettype = ?") - queryParams = append(queryParams, q.Nettype) } - if q.Protocol != "any" { - allowedProtocols := []string{"tor", "http", "https"} - if !slices.Contains(allowedProtocols, q.Protocol) { - return Nodes{}, errors.New("Invalid protocol, must be one of '" + strings.Join(allowedProtocols, "', '") + "' or 'any'") - } + if q.Protocol != "any" && slices.Contains([]string{"tor", "http", "https"}, q.Protocol) { if q.Protocol == "tor" { - whereQueries = append(whereQueries, "is_tor = ?") - queryParams = append(queryParams, 1) + wq = append(wq, "is_tor = ?") + args = append(args, 1) } else { - whereQueries = append(whereQueries, "(protocol = ? AND is_tor = ?)") - queryParams = append(queryParams, q.Protocol) - queryParams = append(queryParams, 0) + wq = append(wq, "(protocol = ? AND is_tor = ?)") + args = append(args, q.Protocol, 0) } } if q.CC != "any" { - whereQueries = append(whereQueries, "country = ?") + wq = append(wq, "country = ?") if q.CC == "UNKNOWN" { - queryParams = append(queryParams, "") + args = append(args, "") } else { - queryParams = append(queryParams, q.CC) + args = append(args, q.CC) } } if q.Status != -1 { - whereQueries = append(whereQueries, "is_available = ?") - queryParams = append(queryParams, q.Status) + wq = append(wq, "is_available = ?") + args = append(args, q.Status) } if q.CORS != -1 { - whereQueries = append(whereQueries, "cors_capable = ?") - queryParams = append(queryParams, 1) + wq = append(wq, "cors_capable = ?") + args = append(args, q.CORS) } - if len(whereQueries) > 0 { - where = "WHERE " + strings.Join(whereQueries, " AND ") + if len(wq) > 0 { + where = "WHERE " + strings.Join(wq, " AND ") } - nodes := Nodes{} + as := []string{"last_checked", "uptime"} + sortBy = "last_checked" + if slices.Contains(as, q.SortBy) { + sortBy = q.SortBy + } + sortDirection = "DESC" + if q.SortDirection == "asc" { + sortDirection = "ASC" + } - queryTotalRows := fmt.Sprintf(` + return args, where, sortBy, sortDirection +} + +// Get nodes from database +func (r *MoneroRepo) Nodes(q QueryNodes) (Nodes, error) { + args, where, sortBy, sortDirection := q.toSQL() + + var nodes Nodes + + qTotal := fmt.Sprintf(` SELECT COUNT(id) AS total_rows FROM tbl_node %s`, where) - err := repo.db.QueryRow(queryTotalRows, queryParams...).Scan(&nodes.TotalRows) + err := r.db.QueryRow(qTotal, args...).Scan(&nodes.TotalRows) if err != nil { return nodes, err } - queryParams = append(queryParams, q.RowsPerPage, (q.Page-1)*q.RowsPerPage) - - allowedSort := []string{"last_checked", "uptime"} - sortBy := "last_checked" - if slices.Contains(allowedSort, q.SortBy) { - sortBy = q.SortBy - } - sortDirection := "DESC" - if q.SortDirection == "asc" { - sortDirection = "ASC" - } + args = append(args, q.RowsPerPage, (q.Page-1)*q.RowsPerPage) query := fmt.Sprintf(` SELECT - id, - protocol, - hostname, - port, - is_tor, - is_available, - nettype, - height, - adjusted_time, - database_size, - difficulty, - version, - uptime, - estimate_fee, - ip_addr, - asn, - asn_name, - country, - country_name, - city, - lat, - lon, - date_entered, - last_checked, - last_check_status, - cors_capable + * FROM tbl_node %s -- where query if any @@ -219,51 +192,9 @@ func (repo *MoneroRepo) Nodes(q QueryNodes) (Nodes, error) { %s LIMIT ? OFFSET ?`, where, sortBy, sortDirection) + err = r.db.Select(&nodes.Items, query, args...) - row, err := repo.db.Query(query, queryParams...) - if err != nil { - return nodes, err - } - defer row.Close() - - nodes.RowsPerPage = q.RowsPerPage - - for row.Next() { - var node Node - err = row.Scan( - &node.ID, - &node.Protocol, - &node.Hostname, - &node.Port, - &node.IsTor, - &node.IsAvailable, - &node.Nettype, - &node.Height, - &node.AdjustedTime, - &node.DatabaseSize, - &node.Difficulty, - &node.Version, - &node.Uptime, - &node.EstimateFee, - &node.IP, - &node.ASN, - &node.ASNName, - &node.CountryCode, - &node.CountryName, - &node.City, - &node.Latitude, - &node.Longitude, - &node.DateEntered, - &node.LastChecked, - &node.LastCheckStatus, - &node.CORSCapable) - if err != nil { - return nodes, err - } - nodes.Items = append(nodes.Items, &node) - } - - return nodes, nil + return nodes, err } type QueryLogs struct { diff --git a/internal/monero/monero_test.go b/internal/monero/monero_test.go new file mode 100644 index 0000000..4fcd544 --- /dev/null +++ b/internal/monero/monero_test.go @@ -0,0 +1,115 @@ +package monero + +import ( + "os" + "reflect" + "strconv" + "testing" + "xmr-remote-nodes/internal/config" + "xmr-remote-nodes/internal/database" +) + +var testMySQL = true + +// TODO: Add database test table and then clean it up +func init() { + // load test db config from OS environment variable + // + // Example: + // TEST_DB_HOST=127.0.0.1 \ + // TEST_DB_PORT=3306 \ + // TEST_DB_USER=testuser \ + // TEST_DB_PASSWORD=testpass \ + // TEST_DB_NAME=testdb go test ./... -v + // + // To run benchmark only, add `-bench=. -run=^#` to the `go test` command + config.DBCfg().Host = os.Getenv("TEST_DB_HOST") + config.DBCfg().Port, _ = strconv.Atoi(os.Getenv("TEST_DB_PORT")) + config.DBCfg().User = os.Getenv("TEST_DB_USER") + config.DBCfg().Password = os.Getenv("TEST_DB_PASSWORD") + config.DBCfg().Name = os.Getenv("TEST_DB_NAME") + + if err := database.ConnectDB(); err != nil { + testMySQL = false + } +} + +func TestQueryNodes_toSQL(t *testing.T) { + tests := []struct { + name string + query QueryNodes + wantArgs []interface{} + wantWhere string + wantSortBy string + wantSortDirection string + }{ + { + name: "Default query", + query: QueryNodes{ + Host: "", + Nettype: "any", + Protocol: "any", + CC: "any", + Status: -1, + CORS: -1, + RowsPerPage: 10, + Page: 1, + SortBy: "last_checked", + SortDirection: "desc", + }, + wantArgs: []interface{}{}, + wantWhere: "", + wantSortBy: "last_checked", + wantSortDirection: "DESC", + }, + { + name: "With host query", + query: QueryNodes{ + Host: "test", + Nettype: "any", + Protocol: "any", + CC: "any", + Status: -1, + CORS: -1, + RowsPerPage: 10, + Page: 1, + SortBy: "last_checked", + SortDirection: "desc", + }, + wantArgs: []interface{}{"%test%", "%test%"}, + wantWhere: "WHERE (hostname LIKE ? OR ip_addr LIKE ?)", + wantSortBy: "last_checked", + wantSortDirection: "DESC", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotArgs, gotWhere, gotSortBy, gotSortDirection := tt.query.toSQL() + + if !equalArgs(gotArgs, tt.wantArgs) { + t.Errorf("QueryNodes.toSQL() gotArgs = %v, want %v", gotArgs, tt.wantArgs) + } + if gotWhere != tt.wantWhere { + t.Errorf("QueryNodes.toSQL() gotWhere = %v, want %v", gotWhere, tt.wantWhere) + } + if gotSortBy != tt.wantSortBy { + t.Errorf("QueryNodes.toSQL() gotSortBy = %v, want %v", gotSortBy, tt.wantSortBy) + } + if gotSortDirection != tt.wantSortDirection { + t.Errorf("QueryNodes.toSQL() gotSortDirection = %v, want %v", gotSortDirection, tt.wantSortDirection) + } + }) + } +} + +func equalArgs(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if !reflect.DeepEqual(v, b[i]) { + return false + } + } + return true +} diff --git a/internal/monero/prober_test.go b/internal/monero/prober_test.go index ac34237..8ffab39 100644 --- a/internal/monero/prober_test.go +++ b/internal/monero/prober_test.go @@ -1,43 +1,13 @@ package monero import ( - "fmt" - "os" - "strconv" "testing" - "xmr-remote-nodes/internal/config" - "xmr-remote-nodes/internal/database" ) -var testMySQL = true - -func init() { - // load test db config from OS environment variable - // - // Example: - // TEST_DB_HOST=127.0.0.1 \ - // TEST_DB_PORT=3306 \ - // TEST_DB_USER=testuser \ - // TEST_DB_PASSWORD=testpass \ - // TEST_DB_NAME=testdb go test ./... -v - // - // To run benchmark only, add `-bench=. -run=^#` to the `go test` command - config.DBCfg().Host = os.Getenv("TEST_DB_HOST") - config.DBCfg().Port, _ = strconv.Atoi(os.Getenv("TEST_DB_PORT")) - config.DBCfg().User = os.Getenv("TEST_DB_USER") - config.DBCfg().Password = os.Getenv("TEST_DB_PASSWORD") - config.DBCfg().Name = os.Getenv("TEST_DB_NAME") - - if err := database.ConnectDB(); err != nil { - testMySQL = false - } -} - // TODO: Add database test table and then clean it up func TestProberRepo_CheckApi(t *testing.T) { if !testMySQL { - fmt.Println("Skip test, not connected to database") t.Skip("Skip test, not connected to database") } tests := []struct { @@ -74,7 +44,6 @@ func TestProberRepo_CheckApi(t *testing.T) { func BenchmarkProberRepo_CheckApi(b *testing.B) { if !testMySQL { - fmt.Println("Skip bench, not connected to database") b.Skip("Skip bench, not connected to database") } repo := NewProber()