Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make result set an iterable #1823

Merged
merged 3 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import akka.actor.{Actor, ActorLogging, Props}
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.NodeParams
import fr.acinq.eclair.channel.Helpers.Closing.{ClosingType, CurrentRemoteClose, LocalClose, MutualClose, NextRemoteClose, RecoveryClose, RevokedClose}
import fr.acinq.eclair.channel.Helpers.Closing._
import fr.acinq.eclair.channel.Monitoring.{Metrics => ChannelMetrics, Tags => ChannelTags}
import fr.acinq.eclair.channel._
import fr.acinq.eclair.db.DbEventHandler.ChannelEvent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package fr.acinq.eclair.db

import java.io.Closeable

import fr.acinq.eclair.blockchain.fee.FeeratesPerKB

import java.io.Closeable

/**
* This database stores the fee rates retrieved by a [[fr.acinq.eclair.blockchain.fee.FeeProvider]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@

package fr.acinq.eclair.db

import java.io.File
import java.nio.file.{Files, StandardCopyOption}

import akka.actor.{Actor, ActorLogging, Props}
import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue}
import fr.acinq.eclair.KamonExt
import fr.acinq.eclair.channel.ChannelPersisted
import fr.acinq.eclair.db.Databases.FileBackup
import fr.acinq.eclair.db.Monitoring.Metrics

import java.io.File
import java.nio.file.{Files, StandardCopyOption}
import scala.sys.process.Process
import scala.util.{Failure, Success, Try}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@

package fr.acinq.eclair.db

import java.io.Closeable

import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}

import java.io.Closeable
import scala.collection.immutable.SortedMap

trait NetworkDb extends Closeable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@

package fr.acinq.eclair.db

import java.io.Closeable
import java.util.UUID

import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.payment._
import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}

import java.io.Closeable
import java.util.UUID

trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable

trait IncomingPaymentsDb {
Expand Down
4 changes: 2 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

package fr.acinq.eclair.db

import java.io.Closeable

import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.wire.protocol.NodeAddress

import java.io.Closeable

trait PeersDb extends Closeable {

def addOrUpdatePeer(nodeId: PublicKey, address: NodeAddress): Unit
Expand Down
42 changes: 24 additions & 18 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ package fr.acinq.eclair.db.jdbc
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.MilliSatoshi
import org.sqlite.SQLiteConnection
import scodec.Codec
import scodec.Decoder
import scodec.bits.{BitVector, ByteVector}

import java.sql.{Connection, ResultSet, Statement, Timestamp}
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue

trait JdbcUtils {

import ExtendedResultSet._

def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = {
val connection = dataSource.getConnection()
try {
Expand Down Expand Up @@ -72,15 +73,16 @@ trait JdbcUtils {
def getVersion(statement: Statement, db_name: String): Option[Int] = {
createVersionTable(statement)
// if there was a previous version installed, this will return a different value from current version
val rs = statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
if (rs.next()) Some(rs.getInt("version")) else None
statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'")
.map(rs => rs.getInt("version"))
.headOption
}

/**
* Updates the version for a particular logical database, it will overwrite the previous version.
*
* NB: we could define this method in [[fr.acinq.eclair.db.sqlite.SqliteUtils]] and [[fr.acinq.eclair.db.pg.PgUtils]]
* but it would make testing more complicated because we need to use one or the other depending on the backend.
* but it would make testing more complicated because we need to use one or the other depending on the backend.
*/
def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = {
createVersionTable(statement)
Expand All @@ -96,20 +98,25 @@ trait JdbcUtils {
}
}

/**
* This helper assumes that there is a "data" column available, decodable with the provided codec
*
* TODO: we should use an scala.Iterator instead
*/
def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = {
var q: Queue[T] = Queue()
while (rs.next()) {
q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value
case class ExtendedResultSet(rs: ResultSet) extends Iterable[ResultSet] {

/**
* Iterates over all rows of a result set.
*
* Careful: the iterator is lazy, it must be materialized before the [[ResultSet]] is closed, by converting the end
* result in a collection or an option.
*/
override def iterator: Iterator[ResultSet] = {
// @formatter:off
new Iterator[ResultSet] {
def hasNext: Boolean = rs.next()
def next(): ResultSet = rs
}
// @formatter:on
}
q
}

case class ExtendedResultSet(rs: ResultSet) {
/** This helper assumes that there is a "data" column available, that can be decoded with the provided codec */
def mapCodec[T](codec: Decoder[T]): Iterable[T] = rs.map(rs => codec.decode(BitVector(rs.getBytes("data"))).require.value)

def getByteVectorFromHex(columnLabel: String): ByteVector = {
val s = rs.getString(columnLabel).stripPrefix("\\x")
Expand Down Expand Up @@ -166,7 +173,6 @@ trait JdbcUtils {
val result = rs.getTimestamp(label)
if (rs.wasNull()) None else Some(result)
}

}

object ExtendedResultSet {
Expand Down
162 changes: 76 additions & 86 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import java.sql.{Statement, Timestamp}
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.collection.immutable.Queue

class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {

Expand Down Expand Up @@ -215,30 +214,28 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var sentByParentId = Map.empty[UUID, PaymentSent]
while (rs.next()) {
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
val part = PaymentSent.PartialPayment(
UUID.fromString(rs.getString("payment_id")),
MilliSatoshi(rs.getLong("amount_msat")),
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getTimestamp("timestamp").getTime)
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
parentId,
rs.getByteVector32FromHex("payment_hash"),
rs.getByteVector32FromHex("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
Seq(part))
}
sentByParentId = sentByParentId + (parentId -> sent)
}
sentByParentId.values.toSeq.sortBy(_.timestamp)
statement.executeQuery()
.foldLeft(Map.empty[UUID, PaymentSent]) { (sentByParentId, rs) =>
val parentId = UUID.fromString(rs.getString("parent_payment_id"))
val part = PaymentSent.PartialPayment(
UUID.fromString(rs.getString("payment_id")),
MilliSatoshi(rs.getLong("amount_msat")),
MilliSatoshi(rs.getLong("fees_msat")),
rs.getByteVector32FromHex("to_channel_id"),
None, // we don't store the route in the audit DB
rs.getTimestamp("timestamp").getTime)
val sent = sentByParentId.get(parentId) match {
case Some(s) => s.copy(parts = s.parts :+ part)
case None => PaymentSent(
parentId,
rs.getByteVector32FromHex("payment_hash"),
rs.getByteVector32FromHex("payment_preimage"),
MilliSatoshi(rs.getLong("recipient_amount_msat")),
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
Seq(part))
}
sentByParentId + (parentId -> sent)
}.values.toSeq.sortBy(_.timestamp)
}
}

Expand All @@ -247,98 +244,91 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var receivedByHash = Map.empty[ByteVector32, PaymentReceived]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getTimestamp("timestamp").getTime)
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
}
receivedByHash = receivedByHash + (paymentHash -> received)
}
receivedByHash.values.toSeq.sortBy(_.timestamp)
statement.executeQuery()
.foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) =>
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = PaymentReceived.PartialPayment(
MilliSatoshi(rs.getLong("amount_msat")),
rs.getByteVector32FromHex("from_channel_id"),
rs.getTimestamp("timestamp").getTime)
val received = receivedByHash.get(paymentHash) match {
case Some(r) => r.copy(parts = r.parts :+ part)
case None => PaymentReceived(paymentHash, Seq(part))
}
receivedByHash + (paymentHash -> received)
}.values.toSeq.sortBy(_.timestamp)
}
}

override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] =
inTransaction { pg =>
var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]
using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
val trampolineByHash = using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val amount = MilliSatoshi(rs.getLong("amount_msat"))
val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id"))
trampolineByHash += (paymentHash -> (amount, nodeId))
}
statement.executeQuery()
.foldLeft(Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]) { (trampolineByHash, rs) =>
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val amount = MilliSatoshi(rs.getLong("amount_msat"))
val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id"))
trampolineByHash + (paymentHash -> (amount, nodeId))
}
}
using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
val relayedByHash = using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]]
while (rs.next()) {
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = RelayedPart(
rs.getByteVector32FromHex("channel_id"),
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getTimestamp("timestamp").getTime)
relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
relayedByHash.flatMap {
case (paymentHash, parts) =>
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
parts.headOption match {
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
}
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) =>
val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey))
TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil
case _ => Nil
}
}.toSeq.sortBy(_.timestamp)
statement.executeQuery()
.foldLeft(Map.empty[ByteVector32, Seq[RelayedPart]]) { (relayedByHash, rs) =>
val paymentHash = rs.getByteVector32FromHex("payment_hash")
val part = RelayedPart(
rs.getByteVector32FromHex("channel_id"),
MilliSatoshi(rs.getLong("amount_msat")),
rs.getString("direction"),
rs.getString("relay_type"),
rs.getTimestamp("timestamp").getTime)
relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part))
}
}
relayedByHash.flatMap {
case (paymentHash, parts) =>
// We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel).
// NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch.
val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount)
parts.headOption match {
case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map {
case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp)
}
case Some(RelayedPart(_, _, _, "trampoline", timestamp)) =>
val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey))
TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil
case _ => Nil
}
}.toSeq.sortBy(_.timestamp)
}

override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] =
inTransaction { pg =>
using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement =>
statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from)))
statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to)))
val rs = statement.executeQuery()
var q: Queue[NetworkFee] = Queue()
while (rs.next()) {
q = q :+ NetworkFee(
statement.executeQuery().map { rs =>
NetworkFee(
remoteNodeId = PublicKey(rs.getByteVectorFromHex("node_id")),
channelId = rs.getByteVector32FromHex("channel_id"),
txId = rs.getByteVector32FromHex("tx_id"),
fee = Satoshi(rs.getLong("fee_sat")),
txType = rs.getString("tx_type"),
timestamp = rs.getTimestamp("timestamp").getTime)
}
q
}.toSeq
}
}

override def stats(from: Long, to: Long): Seq[Stats] = {
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) =>
val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) =>
feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee))
}
case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String)
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) =>
val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) =>
// NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones.
val current = e match {
case c: ChannelPaymentRelayed => Map(
Expand Down
Loading