Separating QueryNode to it's own sql builder func

A quick note that is's ok to use "*" all columns on nodes query
since all columns really want to be displayed to the frontend.
This commit is contained in:
Cristian Ditaputratama 2024-05-30 19:19:03 +07:00
parent 4800bb3284
commit a11986ac0c
Signed by: ditatompel
GPG key ID: 31D3D06D77950979
3 changed files with 166 additions and 151 deletions

View file

@ -69,11 +69,11 @@ type Node struct {
} }
// Get node from database by id // Get node from database by id
func (repo *MoneroRepo) Node(id int) (Node, error) { func (r *MoneroRepo) Node(id int) (Node, error) {
var node Node var node Node
err := repo.db.Get(&node, `SELECT * FROM tbl_node WHERE id = ?`, id) err := r.db.Get(&node, `SELECT * FROM tbl_node WHERE id = ?`, id)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
fmt.Println("WARN:", err) slog.Error("WARN:", err)
return node, errors.New("Can't get node information") return node, errors.New("Can't get node information")
} }
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -92,8 +92,8 @@ type Nodes struct {
// QueryNodes represents database query parameters // QueryNodes represents database query parameters
type QueryNodes struct { type QueryNodes struct {
Host string Host string
Nettype string Nettype string // Can be "any", mainnet, stagenet, testnet. Default: "any"
Protocol string Protocol string // Can be "any", tor, http, https. Default: "any"
CC string // 2 letter country code CC string // 2 letter country code
Status int Status int
CORS int CORS int
@ -105,112 +105,85 @@ type QueryNodes struct {
SortDirection string SortDirection string
} }
// Get nodes from database // toSQL generates SQL query from query parameters
func (repo *MoneroRepo) Nodes(q QueryNodes) (Nodes, error) { func (q QueryNodes) toSQL() (args []interface{}, where, sortBy, sortDirection string) {
queryParams := []interface{}{} wq := []string{}
whereQueries := []string{}
where := ""
if q.Host != "" { if q.Host != "" {
whereQueries = append(whereQueries, "(hostname LIKE ? OR ip_addr LIKE ?)") wq = append(wq, "(hostname LIKE ? OR ip_addr LIKE ?)")
queryParams = append(queryParams, "%"+q.Host+"%") args = append(args, "%"+q.Host+"%", "%"+q.Host+"%")
queryParams = append(queryParams, "%"+q.Host+"%")
} }
if q.Nettype != "any" { if q.Nettype != "any" {
if q.Nettype != "mainnet" && q.Nettype != "stagenet" && q.Nettype != "testnet" { if q.Nettype == "mainnet" || q.Nettype == "stagenet" || q.Nettype == "testnet" {
return Nodes{}, errors.New("Invalid nettype, must be one of 'mainnet', 'stagenet', 'testnet' or 'any'") wq = append(wq, "nettype = ?")
args = append(args, q.Nettype)
} }
whereQueries = append(whereQueries, "nettype = ?")
queryParams = append(queryParams, q.Nettype)
} }
if q.Protocol != "any" { if q.Protocol != "any" && slices.Contains([]string{"tor", "http", "https"}, q.Protocol) {
allowedProtocols := []string{"tor", "http", "https"}
if !slices.Contains(allowedProtocols, q.Protocol) {
return Nodes{}, errors.New("Invalid protocol, must be one of '" + strings.Join(allowedProtocols, "', '") + "' or 'any'")
}
if q.Protocol == "tor" { if q.Protocol == "tor" {
whereQueries = append(whereQueries, "is_tor = ?") wq = append(wq, "is_tor = ?")
queryParams = append(queryParams, 1) args = append(args, 1)
} else { } else {
whereQueries = append(whereQueries, "(protocol = ? AND is_tor = ?)") wq = append(wq, "(protocol = ? AND is_tor = ?)")
queryParams = append(queryParams, q.Protocol) args = append(args, q.Protocol, 0)
queryParams = append(queryParams, 0)
} }
} }
if q.CC != "any" { if q.CC != "any" {
whereQueries = append(whereQueries, "country = ?") wq = append(wq, "country = ?")
if q.CC == "UNKNOWN" { if q.CC == "UNKNOWN" {
queryParams = append(queryParams, "") args = append(args, "")
} else { } else {
queryParams = append(queryParams, q.CC) args = append(args, q.CC)
} }
} }
if q.Status != -1 { if q.Status != -1 {
whereQueries = append(whereQueries, "is_available = ?") wq = append(wq, "is_available = ?")
queryParams = append(queryParams, q.Status) args = append(args, q.Status)
} }
if q.CORS != -1 { if q.CORS != -1 {
whereQueries = append(whereQueries, "cors_capable = ?") wq = append(wq, "cors_capable = ?")
queryParams = append(queryParams, 1) args = append(args, q.CORS)
} }
if len(whereQueries) > 0 { if len(wq) > 0 {
where = "WHERE " + strings.Join(whereQueries, " AND ") where = "WHERE " + strings.Join(wq, " AND ")
} }
nodes := Nodes{} as := []string{"last_checked", "uptime"}
sortBy = "last_checked"
if slices.Contains(as, q.SortBy) {
sortBy = q.SortBy
}
sortDirection = "DESC"
if q.SortDirection == "asc" {
sortDirection = "ASC"
}
queryTotalRows := fmt.Sprintf(` return args, where, sortBy, sortDirection
}
// Get nodes from database
func (r *MoneroRepo) Nodes(q QueryNodes) (Nodes, error) {
args, where, sortBy, sortDirection := q.toSQL()
var nodes Nodes
qTotal := fmt.Sprintf(`
SELECT SELECT
COUNT(id) AS total_rows COUNT(id) AS total_rows
FROM FROM
tbl_node tbl_node
%s`, where) %s`, where)
err := repo.db.QueryRow(queryTotalRows, queryParams...).Scan(&nodes.TotalRows) err := r.db.QueryRow(qTotal, args...).Scan(&nodes.TotalRows)
if err != nil { if err != nil {
return nodes, err return nodes, err
} }
queryParams = append(queryParams, q.RowsPerPage, (q.Page-1)*q.RowsPerPage) args = append(args, q.RowsPerPage, (q.Page-1)*q.RowsPerPage)
allowedSort := []string{"last_checked", "uptime"}
sortBy := "last_checked"
if slices.Contains(allowedSort, q.SortBy) {
sortBy = q.SortBy
}
sortDirection := "DESC"
if q.SortDirection == "asc" {
sortDirection = "ASC"
}
query := fmt.Sprintf(` query := fmt.Sprintf(`
SELECT SELECT
id, *
protocol,
hostname,
port,
is_tor,
is_available,
nettype,
height,
adjusted_time,
database_size,
difficulty,
version,
uptime,
estimate_fee,
ip_addr,
asn,
asn_name,
country,
country_name,
city,
lat,
lon,
date_entered,
last_checked,
last_check_status,
cors_capable
FROM FROM
tbl_node tbl_node
%s -- where query if any %s -- where query if any
@ -219,51 +192,9 @@ func (repo *MoneroRepo) Nodes(q QueryNodes) (Nodes, error) {
%s %s
LIMIT ? LIMIT ?
OFFSET ?`, where, sortBy, sortDirection) OFFSET ?`, where, sortBy, sortDirection)
err = r.db.Select(&nodes.Items, query, args...)
row, err := repo.db.Query(query, queryParams...) return nodes, err
if err != nil {
return nodes, err
}
defer row.Close()
nodes.RowsPerPage = q.RowsPerPage
for row.Next() {
var node Node
err = row.Scan(
&node.ID,
&node.Protocol,
&node.Hostname,
&node.Port,
&node.IsTor,
&node.IsAvailable,
&node.Nettype,
&node.Height,
&node.AdjustedTime,
&node.DatabaseSize,
&node.Difficulty,
&node.Version,
&node.Uptime,
&node.EstimateFee,
&node.IP,
&node.ASN,
&node.ASNName,
&node.CountryCode,
&node.CountryName,
&node.City,
&node.Latitude,
&node.Longitude,
&node.DateEntered,
&node.LastChecked,
&node.LastCheckStatus,
&node.CORSCapable)
if err != nil {
return nodes, err
}
nodes.Items = append(nodes.Items, &node)
}
return nodes, nil
} }
type QueryLogs struct { type QueryLogs struct {

View file

@ -0,0 +1,115 @@
package monero
import (
"os"
"reflect"
"strconv"
"testing"
"xmr-remote-nodes/internal/config"
"xmr-remote-nodes/internal/database"
)
var testMySQL = true
// TODO: Add database test table and then clean it up
func init() {
// load test db config from OS environment variable
//
// Example:
// TEST_DB_HOST=127.0.0.1 \
// TEST_DB_PORT=3306 \
// TEST_DB_USER=testuser \
// TEST_DB_PASSWORD=testpass \
// TEST_DB_NAME=testdb go test ./... -v
//
// To run benchmark only, add `-bench=. -run=^#` to the `go test` command
config.DBCfg().Host = os.Getenv("TEST_DB_HOST")
config.DBCfg().Port, _ = strconv.Atoi(os.Getenv("TEST_DB_PORT"))
config.DBCfg().User = os.Getenv("TEST_DB_USER")
config.DBCfg().Password = os.Getenv("TEST_DB_PASSWORD")
config.DBCfg().Name = os.Getenv("TEST_DB_NAME")
if err := database.ConnectDB(); err != nil {
testMySQL = false
}
}
func TestQueryNodes_toSQL(t *testing.T) {
tests := []struct {
name string
query QueryNodes
wantArgs []interface{}
wantWhere string
wantSortBy string
wantSortDirection string
}{
{
name: "Default query",
query: QueryNodes{
Host: "",
Nettype: "any",
Protocol: "any",
CC: "any",
Status: -1,
CORS: -1,
RowsPerPage: 10,
Page: 1,
SortBy: "last_checked",
SortDirection: "desc",
},
wantArgs: []interface{}{},
wantWhere: "",
wantSortBy: "last_checked",
wantSortDirection: "DESC",
},
{
name: "With host query",
query: QueryNodes{
Host: "test",
Nettype: "any",
Protocol: "any",
CC: "any",
Status: -1,
CORS: -1,
RowsPerPage: 10,
Page: 1,
SortBy: "last_checked",
SortDirection: "desc",
},
wantArgs: []interface{}{"%test%", "%test%"},
wantWhere: "WHERE (hostname LIKE ? OR ip_addr LIKE ?)",
wantSortBy: "last_checked",
wantSortDirection: "DESC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotArgs, gotWhere, gotSortBy, gotSortDirection := tt.query.toSQL()
if !equalArgs(gotArgs, tt.wantArgs) {
t.Errorf("QueryNodes.toSQL() gotArgs = %v, want %v", gotArgs, tt.wantArgs)
}
if gotWhere != tt.wantWhere {
t.Errorf("QueryNodes.toSQL() gotWhere = %v, want %v", gotWhere, tt.wantWhere)
}
if gotSortBy != tt.wantSortBy {
t.Errorf("QueryNodes.toSQL() gotSortBy = %v, want %v", gotSortBy, tt.wantSortBy)
}
if gotSortDirection != tt.wantSortDirection {
t.Errorf("QueryNodes.toSQL() gotSortDirection = %v, want %v", gotSortDirection, tt.wantSortDirection)
}
})
}
}
func equalArgs(a, b []interface{}) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if !reflect.DeepEqual(v, b[i]) {
return false
}
}
return true
}

View file

@ -1,43 +1,13 @@
package monero package monero
import ( import (
"fmt"
"os"
"strconv"
"testing" "testing"
"xmr-remote-nodes/internal/config"
"xmr-remote-nodes/internal/database"
) )
var testMySQL = true
func init() {
// load test db config from OS environment variable
//
// Example:
// TEST_DB_HOST=127.0.0.1 \
// TEST_DB_PORT=3306 \
// TEST_DB_USER=testuser \
// TEST_DB_PASSWORD=testpass \
// TEST_DB_NAME=testdb go test ./... -v
//
// To run benchmark only, add `-bench=. -run=^#` to the `go test` command
config.DBCfg().Host = os.Getenv("TEST_DB_HOST")
config.DBCfg().Port, _ = strconv.Atoi(os.Getenv("TEST_DB_PORT"))
config.DBCfg().User = os.Getenv("TEST_DB_USER")
config.DBCfg().Password = os.Getenv("TEST_DB_PASSWORD")
config.DBCfg().Name = os.Getenv("TEST_DB_NAME")
if err := database.ConnectDB(); err != nil {
testMySQL = false
}
}
// TODO: Add database test table and then clean it up // TODO: Add database test table and then clean it up
func TestProberRepo_CheckApi(t *testing.T) { func TestProberRepo_CheckApi(t *testing.T) {
if !testMySQL { if !testMySQL {
fmt.Println("Skip test, not connected to database")
t.Skip("Skip test, not connected to database") t.Skip("Skip test, not connected to database")
} }
tests := []struct { tests := []struct {
@ -74,7 +44,6 @@ func TestProberRepo_CheckApi(t *testing.T) {
func BenchmarkProberRepo_CheckApi(b *testing.B) { func BenchmarkProberRepo_CheckApi(b *testing.B) {
if !testMySQL { if !testMySQL {
fmt.Println("Skip bench, not connected to database")
b.Skip("Skip bench, not connected to database") b.Skip("Skip bench, not connected to database")
} }
repo := NewProber() repo := NewProber()