diff --git a/.circleci/Dockerfile-jackal b/.circleci/Dockerfile-jackal new file mode 100644 index 000000000..ec2a512ec --- /dev/null +++ b/.circleci/Dockerfile-jackal @@ -0,0 +1,6 @@ +FROM ubuntu:18.04 + +# Jackal +RUN apt-get update && apt-get install --assume-yes sudo mysql-client +COPY ./setup/wait_for_db.sh /root/ +COPY testdata/muc/server/ /root/ \ No newline at end of file diff --git a/.circleci/Dockerfile-profanity b/.circleci/Dockerfile-profanity index f2332c79a..dd0946b3c 100644 --- a/.circleci/Dockerfile-profanity +++ b/.circleci/Dockerfile-profanity @@ -1,7 +1,12 @@ -FROM ubuntu:18.04 +FROM archlinux:base-20211024.0.37588 + +# a hack required to install profanity, glibc needs to be downgraded, the issue is with the arch docker image +RUN patched_glibc=glibc-linux4-2.33-4-x86_64.pkg.tar.zst && \ +curl -LO "https://repo.archlinuxcn.org/x86_64/$patched_glibc" && \ +bsdtar -C / -xvf "$patched_glibc" # Install base -RUN apt-get update && apt-get install --assume-yes profanity +RUN pacman -Sy --noconfirm profanity # Set up the profanity account RUN mkdir --parent /root/.local/share/profanity diff --git a/.circleci/config.yml b/.circleci/config.yml index 0b24cd76b..92932fd9c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -20,10 +20,8 @@ version: 2 jobs: unit_tests: docker: - - image: circleci/golang:1.12 + - image: circleci/golang:1.14 working_directory: /go/src/github.com/ortuman/jackal - environment: - GO111MODULE: "on" steps: - checkout @@ -33,12 +31,11 @@ jobs: set -xe go test ./... - integration: + scion: docker: - - image: circleci/golang:1.12 + - image: circleci/golang:1.14 working_directory: ~/repo environment: - GO111MODULE: "on" BASH_ENV: "~/repo/.circleci/bash_env.sh" JACKAL_DIR: "/root" coreAS1301IP: 172.31.0.110 @@ -52,12 +49,12 @@ jobs: docker_layer_caching: false #not available in the free plan - run: - name: build + name: Build jackal command: | set -xe make build - # Start integration tests + # Start scion integration tests - run: name: Build containers command: | @@ -65,7 +62,7 @@ jobs: sudo -E docker-compose -f .circleci/docker-compose.yml kill # stop containers sudo -E docker-compose -f .circleci/docker-compose.yml down # bring composition down sudo -E docker-compose -f .circleci/docker-compose.yml build - + - run: name: Start AS Containers command: | @@ -132,9 +129,79 @@ jobs: set -xe sudo -E docker exec profanity1 /bin/bash -c 'until grep -q "user2@server2\.xmpp" /root/.local/share/profanity/logs/profanity.log; do cat /root/.local/share/profanity/logs/profanity.log; sleep 1; done;' + groupchat: + docker: + - image: circleci/golang:1.14 + working_directory: ~/repo + environment: + BASH_ENV: "~/repo/.circleci/bash_env.sh" + JACKAL_DIR: "/root" + SERVER_IP: 172.30.0.110 + + steps: + - checkout + + - setup_remote_docker: + docker_layer_caching: false #not available in the free plan + + - run: + name: Build jackal + command: | + set -xe + make build + + - run: + name: Build containers + command: | + set -xe + sudo -E docker-compose -f .circleci/docker-compose-muc.yml kill + sudo -E docker-compose -f .circleci/docker-compose-muc.yml down + sudo -E docker-compose -f .circleci/docker-compose-muc.yml build + + - run: + name: Start server container + command: | + set -xe + sudo -E docker-compose -f .circleci/docker-compose-muc.yml up --no-start jackal_server + sudo -E docker cp ./jackal jackal_server:/root/jackal + sudo -E docker cp ./.circleci/setup/muc/start_server.sh jackal_server:/root/start_server.sh + sudo -E docker-compose -f .circleci/docker-compose-muc.yml up --no-recreate -d jackal_server + + - run: + name: Start jackal server + command: | + set -xe + sudo -E docker exec jackal_server /bin/bash -c '/root/start_server.sh;' + sudo -E docker exec jackal_server /bin/bash -c "until grep -q 'listening\ at\ 0.0.0.0\:5222' ${JACKAL_DIR}/jackal.log; do sleep 1; done;" + + - run: + name: Start client containers + command: | + set -xe + for c in client_owner client_admin; do + sudo -E docker-compose -f .circleci/docker-compose-muc.yml up --no-start $c + sudo -E docker cp ./.circleci/setup/muc/setup_client.sh $c:/root/setup_client.sh + sudo -E docker-compose -f .circleci/docker-compose-muc.yml up --no-recreate -d $c + sudo -E docker exec $c /bin/bash -c '/root/setup_client.sh;' + done + + - run: + name: Owner connects and creates a room + command: | + set -xe + sudo -E docker exec -it -d client_owner profanity -a alice + + - run: + name: Admin connects, joins the room and dms the owner from the room + command: | + set -xe + sudo -E docker exec -it -d client_admin profanity -a bob + sudo -E docker exec client_owner /bin/bash -c 'until grep -q "bob\:\ Hi" /root/.local/share/profanity/chatlogs/alice_at_muc_server.xmpp/netsec_at_conference.muc_server.xmpp_bob/*\.log; do cat /root/.local/share/profanity/alice_at_muc_server.xmpp/netsec_at_conference.muc_server.xmpp_bob/*\.log; sleep 1; done;' + workflows: version: 2 tests: jobs: - unit_tests - - integration + - scion + - groupchat diff --git a/.circleci/docker-compose-muc.yml b/.circleci/docker-compose-muc.yml new file mode 100644 index 000000000..a4e116d9c --- /dev/null +++ b/.circleci/docker-compose-muc.yml @@ -0,0 +1,69 @@ +version: "2" + +networks: + muc_net: + driver: bridge + ipam: + driver: default + config: + - subnet: 172.30.0.0/16 + +services: + jackal_db: + image: mysql:5.7 + restart: always + environment: + MYSQL_ROOT_PASSWORD: 'password' + networks: + - muc_net + + jackal_server: + build: + context: . + dockerfile: Dockerfile-jackal + container_name: jackal_server + depends_on: + - jackal_db + networks: + muc_net: + ipv4_address: ${SERVER_IP} + tty: true + privileged: true + links: + - jackal_db:mysql_host + + client_owner: + build: + context: . + dockerfile: Dockerfile-profanity + container_name: client_owner + networks: + - muc_net + environment: + CLIENT_ID: owner + tty: true + entrypoint: /bin/bash + + client_admin: + build: + context: . + dockerfile: Dockerfile-profanity + container_name: client_admin + networks: + - muc_net + environment: + CLIENT_ID: admin + tty: true + entrypoint: /bin/bash + + client_member: + build: + context: . + dockerfile: Dockerfile-profanity + container_name: client_member + networks: + - muc_net + environment: + CLIENT_ID: member + tty: true + entrypoint: /bin/bash diff --git a/.circleci/setup/install_client.sh b/.circleci/setup/install_client.sh index 29949577a..41a3fe0f0 100755 --- a/.circleci/setup/install_client.sh +++ b/.circleci/setup/install_client.sh @@ -5,11 +5,9 @@ set -xe if [ "$CLIENT_ID" == "client1" ]; then cp -r /root/c1_data/profanity/* /root/.local/share/profanity/ echo "172.31.0.111 server1.xmpp" >> /etc/hosts - cp /root/c1_data/server1.xmpp.crt /usr/local/share/ca-certificates/ - update-ca-certificates + trust anchor --store /root/c1_data/server1.xmpp.crt else cp -r /root/c2_data/profanity/* /root/.local/share/profanity/ echo "172.31.0.112 server2.xmpp" >> /etc/hosts - cp /root/c2_data/server2.xmpp.crt /usr/local/share/ca-certificates/ - update-ca-certificates + trust anchor --store /root/c2_data/server2.xmpp.crt fi diff --git a/.circleci/setup/muc/setup_client.sh b/.circleci/setup/muc/setup_client.sh new file mode 100755 index 000000000..80ace3d23 --- /dev/null +++ b/.circleci/setup/muc/setup_client.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -xe + +# Set up the profanity account +cp -r /root/muc/client_${CLIENT_ID}_data/profanity/* /root/.local/share/profanity/ + +# Add DNS records for the xmpp and muc servers +echo "172.30.0.110 muc_server.xmpp" >> /etc/hosts +echo "172.30.0.110 conference.muc_server.xmpp" >> /etc/hosts + +# Install the server certificate +trust anchor --store /root/muc/server/ssl/muc_server.xmpp.crt diff --git a/.circleci/setup/muc/start_server.sh b/.circleci/setup/muc/start_server.sh new file mode 100755 index 000000000..a62bcc772 --- /dev/null +++ b/.circleci/setup/muc/start_server.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -ex +BASE=$(dirname $0) +cd $BASE + +./wait_for_db.sh + +# create database +echo "GRANT ALL ON jackal.* TO 'jackal'@ IDENTIFIED BY 'password';" | mysql -h mysql_host -u root -ppassword +echo "CREATE DATABASE jackal;" | mysql -h mysql_host -u jackal -ppassword + +# load user data +mysql -h mysql_host -D jackal -u jackal -ppassword < muc_data.sql + +# start jackal +./jackal -c muc_jackal.yml jackal.stdout & diff --git a/.circleci/setup/wait_for_db.sh b/.circleci/setup/wait_for_db.sh index 195219e53..711e7d30a 100755 --- a/.circleci/setup/wait_for_db.sh +++ b/.circleci/setup/wait_for_db.sh @@ -7,7 +7,7 @@ counter=1 while ! mysql --protocol TCP -h mysql_host -u root -ppassword -e "show databases;" > /dev/null 2>&1; do sleep 1 ((counter++)) - if [ $counter -gt 10 ]; then + if [ $counter -gt 60 ]; then >&2 echo "We have been waiting for MySQL too long already; failing." exit 1 fi diff --git a/.circleci/testdata/muc/client_admin_data/muc_server.xmpp.crt b/.circleci/testdata/muc/client_admin_data/muc_server.xmpp.crt new file mode 100644 index 000000000..4a880566c --- /dev/null +++ b/.circleci/testdata/muc/client_admin_data/muc_server.xmpp.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyDCCAzCgAwIBAgIUEMCpBI00Dsni3OXfXTQudKAdEO4wDQYJKoZIhvcNAQEL +BQAwXzELMAkGA1UEBhMCVVMxEDAOBgNVBAgMB3ByaXZhdGUxETAPBgNVBAcMCHBy +b3ZpbmNlMQ0wCwYDVQQKDARjaXR5MRwwGgYDVQQDDBNtdWNfc2VydmVyLnhtcHAu +a2V5MB4XDTIxMTAyNTE0MjIzMFoXDTIzMTAyNTE0MjIzMFowXzELMAkGA1UEBhMC +VVMxEDAOBgNVBAgMB3ByaXZhdGUxETAPBgNVBAcMCHByb3ZpbmNlMQ0wCwYDVQQK +DARjaXR5MRwwGgYDVQQDDBNtdWNfc2VydmVyLnhtcHAua2V5MIIBojANBgkqhkiG +9w0BAQEFAAOCAY8AMIIBigKCAYEApPHgFN8g7OM6jpnC+VZyjG2jksDfwQ5lqB9y +NjBwsunjN7cayUXeTWOzdZLtfCH4umDaAn0eTBocDLmAW8KoQoZ3wtvZ6XeybVHH +I+zpbrm6he8dg4Raz04vfGCAjLE+0fWe2oS11gBhIgdXPPlMYRNd6M/Q5gVJHLNr +IQa8OkHDOunjEfCOghm+T1LDsPO+P3aqVT57CXBPnIujZdsm3+fVF7lmqkdS/ED4 +8weUGPVhoPF3tn+u7xupAoHRAn0M6/+CuXB34BgLuj0KtArIDlSQpGROgRGVCg1d +gPjAMJO476ZUIClk7r7WHDEBAnGfer1bm+6pqzEafK4JWZl+hOYHrRi77nwhmzNj +H4ClwXEleLQERQRI/d2Dmv5zDAbZZ/jXte2ZVjhZ/5heCtlMMf/Q2/Wtb6BHi1GT +jGInBsbU6M+jFLuTAUlnREIWg/c7HC1L0IpEk0nDEpRRrfQ4UI73vRFZ19IDhofR +HlDmVjJHos1AnXdwkUmZKO+mrvC7AgMBAAGjfDB6MB0GA1UdDgQWBBQA4X+kZ1Ve +84r/skjK0dMVNxDetTAfBgNVHSMEGDAWgBQA4X+kZ1Ve84r/skjK0dMVNxDetTAP +BgNVHRMBAf8EBTADAQH/MAsGA1UdDwQEAwIFoDAaBgNVHREEEzARgg9tdWNfc2Vy +dmVyLnhtcHAwDQYJKoZIhvcNAQELBQADggGBAC9Z1AlaNod6XUJOT8EybUAzEslV +6Np7NFgzDK5tfCNWDdYQyO3Yb5gTTA7kz+DIU0fff6QQaXkT6kMMh7WFtb2tc6Wi +mq6fSJ2dolBP2/Cxv59LHEO54ZiMBRrpbaA+6HqGKPTJcWOx87iExZNvlb9gQPfz +8jx55+7lix7+IVid2wAjCZ5UE6+0vseffkBNMJ3lxOAaNJO/kp10w5vG3yz7Q8Vh +H4iategJ2PFPNGijHr5hdxa3zji6rojxDA4CyGH10VvxdZEsZnx4Moh+CfUCgsaC +K3iUrNu1XMkyuROpYekVOdYMTVnkZ2jYhbFNzzm0awq11zk0Di8D71xfgsHVeIbo +8IHeDeRJxMKcK1jYbjoskzvh/vTEZK0KOK8qw/RFxQj9XP0wEVS8b7M2qkEUQbBj +5MZv36ruwcWYNUdXI945Ie4lLGrlO6ysQhqgZgny8LqwT69w4LFOvdG48Xx3OJ8J +GDYeGv1DEU9LOiMlO45w1TGbyMS7fpKfVqLx6g== +-----END CERTIFICATE----- diff --git a/.circleci/testdata/muc/client_admin_data/profanity/accounts b/.circleci/testdata/muc/client_admin_data/profanity/accounts new file mode 100755 index 000000000..50e273fe7 --- /dev/null +++ b/.circleci/testdata/muc/client_admin_data/profanity/accounts @@ -0,0 +1,14 @@ +[bob] +enabled=true +jid=bob@muc_server.xmpp +resource=profanity +muc.nick=bob +presence.last=online +presence.login=online +priority.online=0 +priority.chat=0 +priority.away=0 +priority.xa=0 +priority.dnd=0 +password=asdf +script.start=admin_script diff --git a/.circleci/testdata/muc/client_admin_data/profanity/plugin_settings b/.circleci/testdata/muc/client_admin_data/profanity/plugin_settings new file mode 100755 index 000000000..e69de29bb diff --git a/.circleci/testdata/muc/client_admin_data/profanity/plugin_themes b/.circleci/testdata/muc/client_admin_data/profanity/plugin_themes new file mode 100755 index 000000000..e69de29bb diff --git a/.circleci/testdata/muc/client_admin_data/profanity/scripts/admin_script b/.circleci/testdata/muc/client_admin_data/profanity/scripts/admin_script new file mode 100644 index 000000000..ba47d9197 --- /dev/null +++ b/.circleci/testdata/muc/client_admin_data/profanity/scripts/admin_script @@ -0,0 +1,2 @@ +/join netsec@conference.muc_server.xmpp +/msg netsec@conference.muc_server.xmpp/alice Hi \ No newline at end of file diff --git a/.circleci/testdata/muc/client_owner_data/muc_server.xmpp.crt b/.circleci/testdata/muc/client_owner_data/muc_server.xmpp.crt new file mode 100644 index 000000000..4a880566c --- /dev/null +++ b/.circleci/testdata/muc/client_owner_data/muc_server.xmpp.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyDCCAzCgAwIBAgIUEMCpBI00Dsni3OXfXTQudKAdEO4wDQYJKoZIhvcNAQEL +BQAwXzELMAkGA1UEBhMCVVMxEDAOBgNVBAgMB3ByaXZhdGUxETAPBgNVBAcMCHBy +b3ZpbmNlMQ0wCwYDVQQKDARjaXR5MRwwGgYDVQQDDBNtdWNfc2VydmVyLnhtcHAu +a2V5MB4XDTIxMTAyNTE0MjIzMFoXDTIzMTAyNTE0MjIzMFowXzELMAkGA1UEBhMC +VVMxEDAOBgNVBAgMB3ByaXZhdGUxETAPBgNVBAcMCHByb3ZpbmNlMQ0wCwYDVQQK +DARjaXR5MRwwGgYDVQQDDBNtdWNfc2VydmVyLnhtcHAua2V5MIIBojANBgkqhkiG +9w0BAQEFAAOCAY8AMIIBigKCAYEApPHgFN8g7OM6jpnC+VZyjG2jksDfwQ5lqB9y +NjBwsunjN7cayUXeTWOzdZLtfCH4umDaAn0eTBocDLmAW8KoQoZ3wtvZ6XeybVHH +I+zpbrm6he8dg4Raz04vfGCAjLE+0fWe2oS11gBhIgdXPPlMYRNd6M/Q5gVJHLNr +IQa8OkHDOunjEfCOghm+T1LDsPO+P3aqVT57CXBPnIujZdsm3+fVF7lmqkdS/ED4 +8weUGPVhoPF3tn+u7xupAoHRAn0M6/+CuXB34BgLuj0KtArIDlSQpGROgRGVCg1d +gPjAMJO476ZUIClk7r7WHDEBAnGfer1bm+6pqzEafK4JWZl+hOYHrRi77nwhmzNj +H4ClwXEleLQERQRI/d2Dmv5zDAbZZ/jXte2ZVjhZ/5heCtlMMf/Q2/Wtb6BHi1GT +jGInBsbU6M+jFLuTAUlnREIWg/c7HC1L0IpEk0nDEpRRrfQ4UI73vRFZ19IDhofR +HlDmVjJHos1AnXdwkUmZKO+mrvC7AgMBAAGjfDB6MB0GA1UdDgQWBBQA4X+kZ1Ve +84r/skjK0dMVNxDetTAfBgNVHSMEGDAWgBQA4X+kZ1Ve84r/skjK0dMVNxDetTAP +BgNVHRMBAf8EBTADAQH/MAsGA1UdDwQEAwIFoDAaBgNVHREEEzARgg9tdWNfc2Vy +dmVyLnhtcHAwDQYJKoZIhvcNAQELBQADggGBAC9Z1AlaNod6XUJOT8EybUAzEslV +6Np7NFgzDK5tfCNWDdYQyO3Yb5gTTA7kz+DIU0fff6QQaXkT6kMMh7WFtb2tc6Wi +mq6fSJ2dolBP2/Cxv59LHEO54ZiMBRrpbaA+6HqGKPTJcWOx87iExZNvlb9gQPfz +8jx55+7lix7+IVid2wAjCZ5UE6+0vseffkBNMJ3lxOAaNJO/kp10w5vG3yz7Q8Vh +H4iategJ2PFPNGijHr5hdxa3zji6rojxDA4CyGH10VvxdZEsZnx4Moh+CfUCgsaC +K3iUrNu1XMkyuROpYekVOdYMTVnkZ2jYhbFNzzm0awq11zk0Di8D71xfgsHVeIbo +8IHeDeRJxMKcK1jYbjoskzvh/vTEZK0KOK8qw/RFxQj9XP0wEVS8b7M2qkEUQbBj +5MZv36ruwcWYNUdXI945Ie4lLGrlO6ysQhqgZgny8LqwT69w4LFOvdG48Xx3OJ8J +GDYeGv1DEU9LOiMlO45w1TGbyMS7fpKfVqLx6g== +-----END CERTIFICATE----- diff --git a/.circleci/testdata/muc/client_owner_data/profanity/accounts b/.circleci/testdata/muc/client_owner_data/profanity/accounts new file mode 100755 index 000000000..4e59f9518 --- /dev/null +++ b/.circleci/testdata/muc/client_owner_data/profanity/accounts @@ -0,0 +1,14 @@ +[alice] +enabled=true +jid=alice@muc_server.xmpp +resource=profanity +muc.nick=alice +presence.last=online +presence.login=online +priority.online=0 +priority.chat=0 +priority.away=0 +priority.xa=0 +priority.dnd=0 +password=asdf +script.start=owner_script \ No newline at end of file diff --git a/.circleci/testdata/muc/client_owner_data/profanity/plugin_settings b/.circleci/testdata/muc/client_owner_data/profanity/plugin_settings new file mode 100755 index 000000000..e69de29bb diff --git a/.circleci/testdata/muc/client_owner_data/profanity/plugin_themes b/.circleci/testdata/muc/client_owner_data/profanity/plugin_themes new file mode 100755 index 000000000..e69de29bb diff --git a/.circleci/testdata/muc/client_owner_data/profanity/scripts/owner_script b/.circleci/testdata/muc/client_owner_data/profanity/scripts/owner_script new file mode 100644 index 000000000..735cbcd8f --- /dev/null +++ b/.circleci/testdata/muc/client_owner_data/profanity/scripts/owner_script @@ -0,0 +1,2 @@ +/logging chat on +/join netsec@conference.muc_server.xmpp \ No newline at end of file diff --git a/.circleci/testdata/muc/server/muc_data.sql b/.circleci/testdata/muc/server/muc_data.sql new file mode 100644 index 000000000..7e87daf73 --- /dev/null +++ b/.circleci/testdata/muc/server/muc_data.sql @@ -0,0 +1,622 @@ +-- MariaDB dump 10.19 Distrib 10.6.4-MariaDB, for Linux (x86_64) +-- +-- Host: localhost Database: jackal +-- ------------------------------------------------------ +-- Server version 10.6.4-MariaDB + +/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; +/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; +/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; +/*!40101 SET NAMES utf8mb4 */; +/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; +/*!40103 SET TIME_ZONE='+00:00' */; +/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; +/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; +/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; +/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; + +-- +-- Table structure for table `blocklist_items` +-- + +DROP TABLE IF EXISTS `blocklist_items`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `blocklist_items` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`,`jid`), + KEY `i_blocklist_items_username` (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `blocklist_items` +-- + +LOCK TABLES `blocklist_items` WRITE; +/*!40000 ALTER TABLE `blocklist_items` DISABLE KEYS */; +/*!40000 ALTER TABLE `blocklist_items` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `capabilities` +-- + +DROP TABLE IF EXISTS `capabilities`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `capabilities` ( + `node` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `ver` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `features` text COLLATE utf8mb4_unicode_ci DEFAULT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`node`,`ver`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `capabilities` +-- + +LOCK TABLES `capabilities` WRITE; +/*!40000 ALTER TABLE `capabilities` DISABLE KEYS */; +/*!40000 ALTER TABLE `capabilities` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `occupants` +-- + +DROP TABLE IF EXISTS `occupants`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `occupants` ( + `occupant_jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `bare_jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `affiliation` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL, + `role` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL, + PRIMARY KEY (`occupant_jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `occupants` +-- + +LOCK TABLES `occupants` WRITE; +/*!40000 ALTER TABLE `occupants` DISABLE KEYS */; +/*!40000 ALTER TABLE `occupants` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `offline_messages` +-- + +DROP TABLE IF EXISTS `offline_messages`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `offline_messages` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `data` mediumtext COLLATE utf8mb4_unicode_ci NOT NULL, + `created_at` datetime NOT NULL, + KEY `i_offline_messages_username` (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `offline_messages` +-- + +LOCK TABLES `offline_messages` WRITE; +/*!40000 ALTER TABLE `offline_messages` DISABLE KEYS */; +/*!40000 ALTER TABLE `offline_messages` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `presences` +-- + +DROP TABLE IF EXISTS `presences`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `presences` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `domain` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `resource` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `presence` text COLLATE utf8mb4_unicode_ci NOT NULL, + `node` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `ver` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `allocation_id` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`,`domain`,`resource`), + KEY `i_presences_username_domain` (`username`,`domain`), + KEY `i_presences_domain_resource` (`domain`,`resource`), + KEY `i_presences_allocation_id` (`allocation_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `presences` +-- + +LOCK TABLES `presences` WRITE; +/*!40000 ALTER TABLE `presences` DISABLE KEYS */; +/*!40000 ALTER TABLE `presences` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `private_storage` +-- + +DROP TABLE IF EXISTS `private_storage`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `private_storage` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `namespace` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `data` mediumtext COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`,`namespace`), + KEY `i_private_storage_username` (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `private_storage` +-- + +LOCK TABLES `private_storage` WRITE; +/*!40000 ALTER TABLE `private_storage` DISABLE KEYS */; +/*!40000 ALTER TABLE `private_storage` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `pubsub_affiliations` +-- + +DROP TABLE IF EXISTS `pubsub_affiliations`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `pubsub_affiliations` ( + `node_id` bigint(20) NOT NULL, + `jid` text COLLATE utf8mb4_unicode_ci NOT NULL, + `affiliation` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + UNIQUE KEY `i_pubsub_affiliations_node_id_jid` (`node_id`,`jid`(512)), + KEY `i_pubsub_affiliations_jid` (`jid`(512)) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `pubsub_affiliations` +-- + +LOCK TABLES `pubsub_affiliations` WRITE; +/*!40000 ALTER TABLE `pubsub_affiliations` DISABLE KEYS */; +/*!40000 ALTER TABLE `pubsub_affiliations` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `pubsub_items` +-- + +DROP TABLE IF EXISTS `pubsub_items`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `pubsub_items` ( + `node_id` bigint(20) NOT NULL, + `item_id` text COLLATE utf8mb4_unicode_ci NOT NULL, + `payload` text COLLATE utf8mb4_unicode_ci NOT NULL, + `publisher` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + UNIQUE KEY `i_pubsub_items_node_id_item_id` (`node_id`,`item_id`(36)), + KEY `i_pubsub_items_item_id` (`item_id`(36)), + KEY `i_pubsub_items_node_id_created_at` (`node_id`,`created_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `pubsub_items` +-- + +LOCK TABLES `pubsub_items` WRITE; +/*!40000 ALTER TABLE `pubsub_items` DISABLE KEYS */; +/*!40000 ALTER TABLE `pubsub_items` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `pubsub_node_options` +-- + +DROP TABLE IF EXISTS `pubsub_node_options`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `pubsub_node_options` ( + `node_id` bigint(20) NOT NULL, + `name` text COLLATE utf8mb4_unicode_ci NOT NULL, + `value` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + KEY `i_pubsub_node_options_node_id` (`node_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `pubsub_node_options` +-- + +LOCK TABLES `pubsub_node_options` WRITE; +/*!40000 ALTER TABLE `pubsub_node_options` DISABLE KEYS */; +/*!40000 ALTER TABLE `pubsub_node_options` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `pubsub_nodes` +-- + +DROP TABLE IF EXISTS `pubsub_nodes`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `pubsub_nodes` ( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `host` text COLLATE utf8mb4_unicode_ci NOT NULL, + `name` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `i_pubsub_nodes_host_name` (`host`(256),`name`(512)), + KEY `i_pubsub_nodes_host` (`host`(256)) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `pubsub_nodes` +-- + +LOCK TABLES `pubsub_nodes` WRITE; +/*!40000 ALTER TABLE `pubsub_nodes` DISABLE KEYS */; +/*!40000 ALTER TABLE `pubsub_nodes` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `pubsub_subscriptions` +-- + +DROP TABLE IF EXISTS `pubsub_subscriptions`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `pubsub_subscriptions` ( + `node_id` bigint(20) NOT NULL, + `subid` text COLLATE utf8mb4_unicode_ci NOT NULL, + `jid` text COLLATE utf8mb4_unicode_ci NOT NULL, + `subscription` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + UNIQUE KEY `i_pubsub_subscriptions_node_id_jid` (`node_id`,`jid`(512)), + KEY `i_pubsub_subscriptions_jid` (`jid`(512)) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `pubsub_subscriptions` +-- + +LOCK TABLES `pubsub_subscriptions` WRITE; +/*!40000 ALTER TABLE `pubsub_subscriptions` DISABLE KEYS */; +/*!40000 ALTER TABLE `pubsub_subscriptions` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `resources` +-- + +DROP TABLE IF EXISTS `resources`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `resources` ( + `occupant_jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `resource` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + PRIMARY KEY (`occupant_jid`,`resource`), + KEY `i_occupant_jid` (`occupant_jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `resources` +-- + +LOCK TABLES `resources` WRITE; +/*!40000 ALTER TABLE `resources` DISABLE KEYS */; +/*!40000 ALTER TABLE `resources` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `rooms` +-- + +DROP TABLE IF EXISTS `rooms`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `rooms` ( + `room_jid` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `name` text COLLATE utf8mb4_unicode_ci DEFAULT NULL, + `description` text COLLATE utf8mb4_unicode_ci DEFAULT NULL, + `subject` text COLLATE utf8mb4_unicode_ci DEFAULT NULL, + `language` text COLLATE utf8mb4_unicode_ci DEFAULT NULL, + `locked` tinyint(1) NOT NULL, + `occupants_online` int(11) NOT NULL, + PRIMARY KEY (`room_jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `rooms` +-- + +LOCK TABLES `rooms` WRITE; +/*!40000 ALTER TABLE `rooms` DISABLE KEYS */; +/*!40000 ALTER TABLE `rooms` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `rooms_config` +-- + +DROP TABLE IF EXISTS `rooms_config`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `rooms_config` ( + `room_jid` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `public` tinyint(1) NOT NULL, + `persistent` tinyint(1) NOT NULL, + `pwd_protected` tinyint(1) NOT NULL, + `password` text COLLATE utf8mb4_unicode_ci NOT NULL, + `open` tinyint(1) NOT NULL, + `moderated` tinyint(1) NOT NULL, + `allow_invites` tinyint(1) NOT NULL, + `max_occupants` int(11) NOT NULL, + `allow_subj_change` tinyint(1) NOT NULL, + `non_anonymous` tinyint(1) NOT NULL, + `can_send_pm` varchar(32) COLLATE utf8mb4_unicode_ci NOT NULL, + `can_get_member_list` varchar(32) COLLATE utf8mb4_unicode_ci NOT NULL, + PRIMARY KEY (`room_jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `rooms_config` +-- + +LOCK TABLES `rooms_config` WRITE; +/*!40000 ALTER TABLE `rooms_config` DISABLE KEYS */; +/*!40000 ALTER TABLE `rooms_config` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `rooms_invites` +-- + +DROP TABLE IF EXISTS `rooms_invites`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `rooms_invites` ( + `room_jid` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `user_jid` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + PRIMARY KEY (`room_jid`,`user_jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `rooms_invites` +-- + +LOCK TABLES `rooms_invites` WRITE; +/*!40000 ALTER TABLE `rooms_invites` DISABLE KEYS */; +/*!40000 ALTER TABLE `rooms_invites` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `rooms_users` +-- + +DROP TABLE IF EXISTS `rooms_users`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `rooms_users` ( + `room_jid` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `user_jid` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `occupant_jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + PRIMARY KEY (`room_jid`,`user_jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `rooms_users` +-- + +LOCK TABLES `rooms_users` WRITE; +/*!40000 ALTER TABLE `rooms_users` DISABLE KEYS */; +/*!40000 ALTER TABLE `rooms_users` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `roster_groups` +-- + +DROP TABLE IF EXISTS `roster_groups`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `roster_groups` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `group` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + KEY `i_roster_groups_username_jid` (`username`,`jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `roster_groups` +-- + +LOCK TABLES `roster_groups` WRITE; +/*!40000 ALTER TABLE `roster_groups` DISABLE KEYS */; +/*!40000 ALTER TABLE `roster_groups` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `roster_items` +-- + +DROP TABLE IF EXISTS `roster_items`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `roster_items` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `name` text COLLATE utf8mb4_unicode_ci NOT NULL, + `subscription` text COLLATE utf8mb4_unicode_ci NOT NULL, + `groups` text COLLATE utf8mb4_unicode_ci NOT NULL, + `ask` tinyint(1) NOT NULL, + `ver` int(11) NOT NULL DEFAULT 0, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`,`jid`), + KEY `i_roster_items_username` (`username`), + KEY `i_roster_items_jid` (`jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `roster_items` +-- + +LOCK TABLES `roster_items` WRITE; +/*!40000 ALTER TABLE `roster_items` DISABLE KEYS */; +/*!40000 ALTER TABLE `roster_items` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `roster_notifications` +-- + +DROP TABLE IF EXISTS `roster_notifications`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `roster_notifications` ( + `contact` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `jid` varchar(512) COLLATE utf8mb4_unicode_ci NOT NULL, + `elements` text COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`contact`,`jid`), + KEY `i_roster_notifications_jid` (`jid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `roster_notifications` +-- + +LOCK TABLES `roster_notifications` WRITE; +/*!40000 ALTER TABLE `roster_notifications` DISABLE KEYS */; +/*!40000 ALTER TABLE `roster_notifications` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `roster_versions` +-- + +DROP TABLE IF EXISTS `roster_versions`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `roster_versions` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `ver` int(11) NOT NULL DEFAULT 0, + `last_deletion_ver` int(11) NOT NULL DEFAULT 0, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `roster_versions` +-- + +LOCK TABLES `roster_versions` WRITE; +/*!40000 ALTER TABLE `roster_versions` DISABLE KEYS */; +/*!40000 ALTER TABLE `roster_versions` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `users` +-- + +DROP TABLE IF EXISTS `users`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `users` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `password` text COLLATE utf8mb4_unicode_ci NOT NULL, + `last_presence` text COLLATE utf8mb4_unicode_ci NOT NULL, + `last_presence_at` datetime NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `users` +-- + +LOCK TABLES `users` WRITE; +/*!40000 ALTER TABLE `users` DISABLE KEYS */; +INSERT INTO `users` VALUES ('alice','asdf','','2021-10-24 18:42:58','2021-10-24 18:42:58','2021-10-24 18:42:58'),('bob','asdf','','2021-10-24 18:42:58','2021-10-24 18:42:58','2021-10-24 18:42:58'),('carol','asdf','','2021-10-24 18:42:58','2021-10-24 18:42:58','2021-10-24 18:42:58'); +/*!40000 ALTER TABLE `users` ENABLE KEYS */; +UNLOCK TABLES; + +-- +-- Table structure for table `vcards` +-- + +DROP TABLE IF EXISTS `vcards`; +/*!40101 SET @saved_cs_client = @@character_set_client */; +/*!40101 SET character_set_client = utf8 */; +CREATE TABLE `vcards` ( + `username` varchar(256) COLLATE utf8mb4_unicode_ci NOT NULL, + `vcard` mediumtext COLLATE utf8mb4_unicode_ci NOT NULL, + `updated_at` datetime NOT NULL, + `created_at` datetime NOT NULL, + PRIMARY KEY (`username`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table `vcards` +-- + +LOCK TABLES `vcards` WRITE; +/*!40000 ALTER TABLE `vcards` DISABLE KEYS */; +/*!40000 ALTER TABLE `vcards` ENABLE KEYS */; +UNLOCK TABLES; +/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; + +/*!40101 SET SQL_MODE=@OLD_SQL_MODE */; +/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; +/*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; +/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; +/*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; +/*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; +/*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; + +-- Dump completed on 2021-10-25 16:01:21 diff --git a/.circleci/testdata/muc/server/muc_jackal.yml b/.circleci/testdata/muc/server/muc_jackal.yml new file mode 100644 index 000000000..60b34f117 --- /dev/null +++ b/.circleci/testdata/muc/server/muc_jackal.yml @@ -0,0 +1,105 @@ +# jackal default configuration file + +pid_path: jackal.pid + +logger: + level: debug + log_path: jackal.log + +storage: + type: mysql + mysql: + host: mysql_host:3306 + user: jackal + password: password + database: jackal + pool_size: 16 + +hosts: + - name: muc_server.xmpp + tls: + privkey_path: "ssl/muc_server.xmpp.key" + cert_path: "ssl/muc_server.xmpp.crt" + +modules: + enabled: + - roster # Roster + - last_activity # XEP-0012: Last Activity + - private # XEP-0049: Private XML Storage + - vcard # XEP-0054: vcard-temp + - registration # XEP-0077: In-Band Registration + - version # XEP-0092: Software Version + - pep # XEP-0163: Personal Eventing Protocol + - blocking_command # XEP-0191: Blocking Command + - ping # XEP-0199: XMPP Ping + - offline # Offline storage + - muc # XEP-0045: Multi-User Chat + + mod_roster: + versioning: true + + mod_offline: + queue_size: 2500 + + mod_registration: + allow_registration: yes + allow_change: yes + allow_cancel: yes + + mod_version: + show_os: true + + mod_ping: + send: no + send_interval: 60 + + mod_muc: + host: conference.muc_server.xmpp + name: "Test Chatroom Server" + room_defaults: + public: true + persistent: true + password_protected: false + open: true + moderated: false + allow_invites: false + allow_subject_change: true + enable_logging: true + non_anonymous: true + occupant_count: -1 # -1 means don't set the limit + # options for the next ones are "all", "moderators" and "" + can_get_member_list: "all" + send_pm: "all" + +c2s: + - id: default + + connect_timeout: 5 + keep_alive: 120 + + max_stanza_size: 65536 + resource_conflict: replace # [override, replace, reject] + + transport: + type: socket # websocket + bind_addr: 0.0.0.0 + port: 5222 + # url_path: /xmpp/ws + + compression: + level: default + + sasl: + - plain + - scram_sha_1 + - scram_sha_256 + +s2s: + dial_timeout: 15 + keep_alive: 600 + dialback_secret: s3cr3tf0rd14lb4ck + max_stanza_size: 131072 + + transport: + bind_addr: 0.0.0.0 + port: 5269 diff --git a/.circleci/testdata/muc/server/ssl/muc_server.xmpp.crt b/.circleci/testdata/muc/server/ssl/muc_server.xmpp.crt new file mode 100644 index 000000000..4a880566c --- /dev/null +++ b/.circleci/testdata/muc/server/ssl/muc_server.xmpp.crt @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyDCCAzCgAwIBAgIUEMCpBI00Dsni3OXfXTQudKAdEO4wDQYJKoZIhvcNAQEL +BQAwXzELMAkGA1UEBhMCVVMxEDAOBgNVBAgMB3ByaXZhdGUxETAPBgNVBAcMCHBy +b3ZpbmNlMQ0wCwYDVQQKDARjaXR5MRwwGgYDVQQDDBNtdWNfc2VydmVyLnhtcHAu +a2V5MB4XDTIxMTAyNTE0MjIzMFoXDTIzMTAyNTE0MjIzMFowXzELMAkGA1UEBhMC +VVMxEDAOBgNVBAgMB3ByaXZhdGUxETAPBgNVBAcMCHByb3ZpbmNlMQ0wCwYDVQQK +DARjaXR5MRwwGgYDVQQDDBNtdWNfc2VydmVyLnhtcHAua2V5MIIBojANBgkqhkiG +9w0BAQEFAAOCAY8AMIIBigKCAYEApPHgFN8g7OM6jpnC+VZyjG2jksDfwQ5lqB9y +NjBwsunjN7cayUXeTWOzdZLtfCH4umDaAn0eTBocDLmAW8KoQoZ3wtvZ6XeybVHH +I+zpbrm6he8dg4Raz04vfGCAjLE+0fWe2oS11gBhIgdXPPlMYRNd6M/Q5gVJHLNr +IQa8OkHDOunjEfCOghm+T1LDsPO+P3aqVT57CXBPnIujZdsm3+fVF7lmqkdS/ED4 +8weUGPVhoPF3tn+u7xupAoHRAn0M6/+CuXB34BgLuj0KtArIDlSQpGROgRGVCg1d +gPjAMJO476ZUIClk7r7WHDEBAnGfer1bm+6pqzEafK4JWZl+hOYHrRi77nwhmzNj +H4ClwXEleLQERQRI/d2Dmv5zDAbZZ/jXte2ZVjhZ/5heCtlMMf/Q2/Wtb6BHi1GT +jGInBsbU6M+jFLuTAUlnREIWg/c7HC1L0IpEk0nDEpRRrfQ4UI73vRFZ19IDhofR +HlDmVjJHos1AnXdwkUmZKO+mrvC7AgMBAAGjfDB6MB0GA1UdDgQWBBQA4X+kZ1Ve +84r/skjK0dMVNxDetTAfBgNVHSMEGDAWgBQA4X+kZ1Ve84r/skjK0dMVNxDetTAP +BgNVHRMBAf8EBTADAQH/MAsGA1UdDwQEAwIFoDAaBgNVHREEEzARgg9tdWNfc2Vy +dmVyLnhtcHAwDQYJKoZIhvcNAQELBQADggGBAC9Z1AlaNod6XUJOT8EybUAzEslV +6Np7NFgzDK5tfCNWDdYQyO3Yb5gTTA7kz+DIU0fff6QQaXkT6kMMh7WFtb2tc6Wi +mq6fSJ2dolBP2/Cxv59LHEO54ZiMBRrpbaA+6HqGKPTJcWOx87iExZNvlb9gQPfz +8jx55+7lix7+IVid2wAjCZ5UE6+0vseffkBNMJ3lxOAaNJO/kp10w5vG3yz7Q8Vh +H4iategJ2PFPNGijHr5hdxa3zji6rojxDA4CyGH10VvxdZEsZnx4Moh+CfUCgsaC +K3iUrNu1XMkyuROpYekVOdYMTVnkZ2jYhbFNzzm0awq11zk0Di8D71xfgsHVeIbo +8IHeDeRJxMKcK1jYbjoskzvh/vTEZK0KOK8qw/RFxQj9XP0wEVS8b7M2qkEUQbBj +5MZv36ruwcWYNUdXI945Ie4lLGrlO6ysQhqgZgny8LqwT69w4LFOvdG48Xx3OJ8J +GDYeGv1DEU9LOiMlO45w1TGbyMS7fpKfVqLx6g== +-----END CERTIFICATE----- diff --git a/.circleci/testdata/muc/server/ssl/muc_server.xmpp.key b/.circleci/testdata/muc/server/ssl/muc_server.xmpp.key new file mode 100644 index 000000000..8b7a86f95 --- /dev/null +++ b/.circleci/testdata/muc/server/ssl/muc_server.xmpp.key @@ -0,0 +1,39 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIG4wIBAAKCAYEApPHgFN8g7OM6jpnC+VZyjG2jksDfwQ5lqB9yNjBwsunjN7ca +yUXeTWOzdZLtfCH4umDaAn0eTBocDLmAW8KoQoZ3wtvZ6XeybVHHI+zpbrm6he8d +g4Raz04vfGCAjLE+0fWe2oS11gBhIgdXPPlMYRNd6M/Q5gVJHLNrIQa8OkHDOunj +EfCOghm+T1LDsPO+P3aqVT57CXBPnIujZdsm3+fVF7lmqkdS/ED48weUGPVhoPF3 +tn+u7xupAoHRAn0M6/+CuXB34BgLuj0KtArIDlSQpGROgRGVCg1dgPjAMJO476ZU +IClk7r7WHDEBAnGfer1bm+6pqzEafK4JWZl+hOYHrRi77nwhmzNjH4ClwXEleLQE +RQRI/d2Dmv5zDAbZZ/jXte2ZVjhZ/5heCtlMMf/Q2/Wtb6BHi1GTjGInBsbU6M+j +FLuTAUlnREIWg/c7HC1L0IpEk0nDEpRRrfQ4UI73vRFZ19IDhofRHlDmVjJHos1A +nXdwkUmZKO+mrvC7AgMBAAECggGAZxgO2LzFlYpI1Uxhwvo3SnJUpKsMr1vSSgyt +lBUeu5TYQcCea2LSGUjRqBEXglixX7ydRqTxRNuk4IcpJTE7fakSPaawQu5fhVhx +wZCYLm7DmGbl6YfWQnA52eFvN4CpJQ4CJc6A4KsICv7PlfqztJEoRxVtGff/xIKX +2OKez0K/RZleJZ5XVBXHD0lJqtYN+RiwSettd27NM6lLjaQ2XghG9jcZZiCCS8xI +TY5VEGx+gtup2VVg/24oarq86nNBmUMI0hihFHiSwNxSoFjurTRMFc8RIsP1xmsB +chHFG1DO+fZrRFDGTykKpvOOQrotg19R5Z2xTN51IHWA1f8gSXKhU0ABUbvmCABG +7EI0GaF+IVHREkDSoboGVO4t2yQWBdWa3TFZzpL6odoW1g7w0hb5CsaTApbwoYSH +9Otol7OSymAzEFi0GrWISgDldv0WxenXs+wHjig8l/hUiYM2FwI6xfoY++ROrv5N +wQSQjCh/pNkKtNR13d/vp0bN7LsRAoHBANXqNqbROX65St52p2QFtblnCwhw94WX +XkKFrvWFUa5BUEOmBRsokYDG1Ejji2y3M7bqYv2/4YHITmA6FA/JRht+SpTbpDcO +DkZ/uZYB3Zn7boNLIVYNlIYQlWcRT4SDMlaSocgS/X7U1sQdkiMkcTSclFrYpcbS +uqFQIfl0o7lZitcZ+1c+7tnWhzTzVjkDIbUQnZ5ehe+3seY6bEaiM/fPO+PIPJmD +noBPs5IzfPSMw81zpp6K5ZjWKfFnhf5VxwKBwQDFZUrw078KzTC4FRy/1GFNUnV7 ++374C7QYjwUSP81FzOasZHiZtUz5WoDk7OgNcdKwDq9wzM5SCK1RxXvMIdojQ/K0 +j2j6CBguZtkE2DzcqoBH/lbXAMYgTPlPDy6lQ1+cBRLTaphxJrupcd4dpInJjdCP +qkMJZgXa9iwn393ahNwv9uSfDWTSx7cAkD0f+iaCJPUbQtv8gUnIHN70c9r1ttxO +8wzIIQF4lRswbgxE2pni64jiiwff9A5dM7lsPW0CgcA2CGUpmeu18MPSkZISl8ah +QD/vL+m12tg6YV4iMjzBkUne7I6Zn5OxPYfdqWxMV7I+X9IFWnRxvdDeSY1Lt4F3 +7FyWYSyHo4tDj3unQm1hObteepm/DMsZWhMC58J4LFOIvp0S1oklgkmFXBrPuaLG +sU5f7B1jrVLq8DzEsAuzA8UoNC3iicj3SVL5swVyfTIEdF/74tBeDr5m8xbqeIo7 +7CKqlKeJ908QBhHLuimz8+J5gN7zaQ45ns8VZRrZdX0CgcEAgA9f1+83UHWf9rzT +ouvQmE43o9CcJaXKF+ipHJbBwFfXCt4/k1CYeGirmFjNjvNaARf135MvAu9YKeWZ +k+weTaSmT1d+fj9EGM5mWotCqNIAWTR6+A8regcV8qFh3Jth6qEr/nZ4IWhEdQBH +XinBrj3PoXSO+wjyIYR1TwUgM8Tf6EofNcSFtW0Vn0r0LMunseTyEtaES0NBL47z +YdnLon8YXCu4DhnDj4wEUJ4EXaMIFLxDWN1jpXn626JT2BHdAoHAaluO46HEwp0B +QEErdhpdaWcLs6Q4UTkpjz36MnLhPDJ5FPAaFSeJsbDrGNgrBNJVMD8lXLJJ5z5r +tO4YhJ0+u4UXYH5mY09pdL29uFZMUr9lqasFfnpXSfj2eAIDIgPlidAJObpDYha6 +HuN0iMPArbB0/r5MoLpiRLdUUHKDDnKoWWFd7zBYMm3WDLTGs2UaJcoypSkuY2TO +BSO4B6qOIRuy9tKtMRZDlIeKgYoP2z5ur0MQMK3ZTEX+zZa8YEhj +-----END RSA PRIVATE KEY----- diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..f139f16ef --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,2 @@ +github: ortuman +patreon: ortuman diff --git a/.github/gopher.png b/.github/gopher.png deleted file mode 100644 index 1d2ef674d..000000000 Binary files a/.github/gopher.png and /dev/null differ diff --git a/.travis.yml b/.travis.yml index 5b23d6228..1ab9f3d23 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,14 +3,13 @@ language: go env: global: - CC_TEST_REPORTER_ID=71266456503f8f3caaae69674d933342c3fce1e2cced98834a2820bfcd6ebe01 - - GO111MODULE=on os: - linux - osx go: - - 1.12.x + - 1.13.x before_script: - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-darwin-amd64 > ./cc-test-reporter; fi diff --git a/CHANGELOG.md b/CHANGELOG.md index c39b2f411..566a62e0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,99 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). -## [0.5.1] - 2019-05-03 +## [0.10.1] - 2020-03-22 ### Changed -- Implemented lock free actor run queue. - -## [0.5.0] - 2019-04-08 -### Added -- Offline message gateway support. - -## [0.4.11] - 2019-04-07 -### Fixed -- Fixed pgsql private storage. - -## [0.4.10] - 2019-03-17 -### Fixed -- Fallback to standard port on SRV resolution error. -- Use serialization intermediate buffer on socket send. - -## [0.4.9] - 2019-02-07 -### Fixed -- In-band registration bug. - -## [0.4.8] - 2019-01-23 -### Fixed -- S2S iq module routing. - -## [0.4.7] - 2019-01-22 -### Added -- SCRAM-SHA-512 authentication method. - -## [0.4.6] - 2019-01-19 -### Fixed -- Fixed Gajim client connecting issue. - -## [0.4.5] - 2019-01-16 -### Added -- PostgreSQL support. - -## [0.4.0] - 2019-01-01 -### Added -- Cluster mode support. šŸ„³ - -## [0.3.6] - 2018-12-15 -### Fixed -- Fixed bug in roster item deletion. - -## [0.3.5] - 2018-11-09 -### Fixed -- Fixed c2s and s2s message routing. - -## [0.3.4] - 2018-11-03 -### Added -- Built-in graceful shutdown support. - -## [0.3.3] - 2018-10-03 -### Changed -- New component interface. - -## [0.3.2] - 2018-09-04 -### Fixed -- Bug fixes. +- Set resource limit +## [0.9.0] - 2020-02-15 ### Changed -- New module interface. - -## [0.3.1] - 2018-07-17 -### Fixed -- IQ routing bug. - -## [0.3.0] - 2018-07-06 -### Added -- Added S2S support. - -### Removed -- Removed CGO dependency... thanks Sam Whited! šŸ˜‰ - -### Fixed -- crash: invalid XML parsing. - -## [0.2.0] - 2018-05-08 -### Added -- Added support for XEP-0191 (Blocking Command) -- Added support for XEP-0012 (Last Activity) -- Added support for XEP-0237 (Roster Versioning) -- RFC 7395: XMPP Subprotocol for WebSocket +- Router implementation refactoring -## [0.1.15] - 2018-03-20 -### Added -- Initial release (https://xmpp.org/rfcs/rfc3921.html) -- Added support for XEP-0030 (Service Discovery) -- Added support for XEP-0049 (Private XML Storage) -- Added support for XEP-0054 (vcard-temp) -- Added support for XEP-0077 (In-Band Registration) -- Added support for XEP-0092 (Software Version) -- Added support for XEP-0138 (Stream Compression) -- Added support for XEP-0160 (Best Practices for Handling Offline Messages) -- Added support for XEP-0199 (XMPP Ping) diff --git a/Makefile b/Makefile index 0d698660d..c8c655501 100644 --- a/Makefile +++ b/Makefile @@ -1,42 +1,67 @@ +.POSIX: +.SUFFIXES: + +GOFILES!=find . -name '*.go' + +GOLDFLAGS =-s -w -extldflags $(LDFLAGS) + +.PHONY: install install: - @GO111MODULE=on go install -ldflags="-s -w" + @go install -ldflags="-s -w" +.PHONY: install-tools install-tools: - @GO111MODULE=on go get -u \ + @env GO111MODULE=off go get -u \ golang.org/x/lint/golint \ golang.org/x/tools/cmd/goimports +.PHONY: fmt fmt: install-tools @echo "Checking go files format..." @GOIMP=$$(for f in $$(find . -type f -name "*.go" ! -path "./.cache/*" ! -path "./vendor/*" ! -name "bindata.go") ; do \ - goimports -l $$f ; \ - done) && echo $$GOIMP && test -z "$$GOIMP" + goimports -l $$f ; \ + done) && echo $$GOIMP && test -z "$$GOIMP" -build: +go.sum: $(GOFILES) go.mod + go mod tidy + +jackal: $(GOFILES) go.mod go.sum @echo "Building binary..." - @GO111MODULE=on go build -ldflags="-s -w" + @go build\ + -trimpath \ + -o $@ \ + -ldflags "$(GOLDFLAGS)" + +.PHONY: build +build: jackal +.PHONY: test test: @echo "Running tests..." @GO111MODULE=on go test -race $$(go list ./...) +.PHONY: coverate coverage: @echo "Generating coverage profile..." @go test -race -coverprofile=coverage.txt -covermode=atomic $$(go list ./...) +.PHONY: vet vet: @echo "Searching for buggy code..." @go vet $$(go list ./...) +.PHONY: lint lint: install-tools @echo "Running linter..." @golint $$(go list ./...) +.PHONY: dockerimage dockerimage: @echo "Building binary..." - @env GO111MODULE=on GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags="-s -w" + @env GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags="-s -w" @echo "Building docker image..." @docker build -f dockerfiles/Dockerfile -t ortuman/jackal . +.PHONY: clean clean: @go clean diff --git a/README.md b/README.md index 4f5a51c98..0b2fa4a88 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ An XMPP server written in Go. This repository is a fork of [ortuman/jackal](https://github.com/ortuman/jackal) making it available for SCION/QUIC. Refer to the original repository for general usage. -If you have go1.12 installed (not supporting go1.11 at the moment), you can build jackal using Makefile. +If you have go1.13 installed (or go1.14, not tested with more up to date versions), you can build jackal using Makefile. ```shell make build ``` @@ -13,6 +13,7 @@ You can check if the project has built successfully by running the following com ```shell ./jackal -h ``` + ## Running jackal In order to run jackal, you have to specify the configuration in .yml file. An example .yml file is provided in the repository as example.jackal.yml. You need to do the following steps before you can run the server with the configuration specified in the example.jackal.yml. @@ -33,8 +34,9 @@ mysql_secure_installation ``` Grant right to a dedicated 'jackal' user (replace `password` with your desired password). -```shell -echo "GRANT ALL ON jackal.* TO 'jackal'@'localhost' IDENTIFIED BY 'password';" | mysql -h localhost -u root -p +```sh +echo "CREATE USER IF NOT EXISTS 'jackal'@'localhost' IDENTIFIED BY 'password';" | mysql -h localhost -u root -p +echo "GRANT ALL ON jackal.* TO 'jackal'@'localhost';" | mysql -h localhost -u root -p ``` Create 'jackal' database (using previously created password). diff --git a/app/app.go b/app/app.go index f1d3b8551..8c73d702c 100644 --- a/app/app.go +++ b/app/app.go @@ -16,23 +16,31 @@ import ( "os" "os/signal" "path/filepath" + "runtime" "strconv" "syscall" "time" + "github.com/google/uuid" "github.com/ortuman/jackal/c2s" - "github.com/ortuman/jackal/cluster" + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/component" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/router/host" "github.com/ortuman/jackal/s2s" + s2srouter "github.com/ortuman/jackal/s2s/router" "github.com/ortuman/jackal/storage" "github.com/ortuman/jackal/version" "github.com/pkg/errors" ) const ( + envAllocationID = "JACKAL_ALLOCATION_ID" + + darwinOpenMax = 10240 + defaultShutDownWaitTime = time.Duration(5) * time.Second ) @@ -60,11 +68,10 @@ type Application struct { output io.Writer args []string logger log.Logger - storage storage.Storage - cluster *cluster.Cluster - router *router.Router + router router.Router mods *module.Modules comps *component.Components + s2sOutProvider *s2s.OutProvider s2s *s2s.S2S c2s *c2s.C2S debugSrv *http.Server @@ -100,11 +107,11 @@ func (a *Application) Run() error { fs.StringVar(&configFile, "c", "/etc/jackal/jackal.yml", "Configuration file path.") fs.Usage = func() { for i := range logoStr { - fmt.Fprintf(a.output, "%s\n", logoStr[i]) + _, _ = fmt.Fprintf(a.output, "%s\n", logoStr[i]) } - fmt.Fprintf(a.output, "%s\n", usageStr) + _, _ = fmt.Fprintf(a.output, "%s\n", usageStr) } - fs.Parse(a.args[1:]) + _ = fs.Parse(a.args[1:]) // print usage if showUsage { @@ -113,7 +120,7 @@ func (a *Application) Run() error { } // print version if showVersion { - fmt.Fprintf(a.output, "jackal version: %v\n", version.ApplicationVersion) + _, _ = fmt.Fprintf(a.output, "jackal version: %v\n", version.ApplicationVersion) return nil } // load configuration @@ -122,62 +129,70 @@ func (a *Application) Run() error { if err != nil { return err } + // create PID file if err := a.createPIDFile(cfg.PIDFile); err != nil { return err } - // initialize logger err = a.initLogger(&cfg.Logger, a.output) if err != nil { return err } + // set allocation identifier + allocID := os.Getenv(envAllocationID) + if len(allocID) == 0 { + allocID = uuid.New().String() + } + // show jackal's fancy logo - a.printLogo() + a.printLogo(allocID) // initialize storage - err = a.initStorage(&cfg.Storage) + repContainer, err := storage.New(&cfg.Storage) if err != nil { return err } + if err := repContainer.Presences().ClearPresences(context.Background()); err != nil { + return err + } - // initialize router - a.router, err = router.New(&cfg.Router) + // initialize hosts + hosts, err := host.New(cfg.Hosts) if err != nil { return err } + // initialize router + var s2sRouter router.S2SRouter - // initialize cluster - if cfg.Cluster != nil { - if storage.IsClusterCompatible() { - a.cluster, err = cluster.New(cfg.Cluster, a.router.ClusterDelegate()) - if err != nil { - return err - } - if a.cluster != nil { - a.router.SetCluster(a.cluster) - if err := a.cluster.Join(); err != nil { - log.Warnf("%v", err) - } - } - } else { - log.Warnf("cluster mode disabled: storage type '%s' is not compatible", cfg.Storage.Type) - } + if cfg.S2S != nil { + a.s2sOutProvider = s2s.NewOutProvider(cfg.S2S, hosts) + s2sRouter = s2srouter.New(a.s2sOutProvider) + } + a.router, err = router.New( + hosts, + c2srouter.New(repContainer.User(), repContainer.BlockList()), + s2sRouter, + ) + if err != nil { + return err } // initialize modules & components... - a.mods = module.New(&cfg.Modules, a.router) + a.mods = module.New(&cfg.Modules, a.router, repContainer, allocID) a.comps = component.New(&cfg.Components, a.mods.DiscoInfo) // start serving s2s... - a.s2s = s2s.New(cfg.S2S, a.mods, a.router) - if a.s2s != nil { - a.router.SetOutS2SProvider(a.s2s) + if err := a.setRLimit(); err != nil { + return err + } + if cfg.S2S != nil { + a.s2s = s2s.New(cfg.S2S, a.mods, a.s2sOutProvider, a.router) a.s2s.Start() } // start serving c2s... - a.c2s, err = c2s.New(cfg.C2S, a.mods, a.comps, a.router) + a.c2s, err = c2s.New(cfg.C2S, a.mods, a.comps, a.router, repContainer.User(), repContainer.BlockList()) if err != nil { return err } @@ -198,7 +213,7 @@ func (a *Application) Run() error { } func (a *Application) showVersion() { - fmt.Fprintf(a.output, "jackal version: %v\n", version.ApplicationVersion) + _, _ = fmt.Fprintf(a.output, "jackal version: %v\n", version.ApplicationVersion) } func (a *Application) createPIDFile(pidFile string) error { @@ -212,7 +227,7 @@ func (a *Application) createPIDFile(pidFile string) error { if err != nil { return err } - defer file.Close() + defer func() { _ = file.Close() }() currentPid := os.Getpid() if _, err := file.WriteString(strconv.FormatInt(int64(currentPid), 10)); err != nil { @@ -243,22 +258,32 @@ func (a *Application) initLogger(config *loggerConfig, output io.Writer) error { return nil } -func (a *Application) initStorage(config *storage.Config) error { - s, err := storage.New(config) - if err != nil { - return err - } - a.storage = s - storage.Set(a.storage) - return nil -} - -func (a *Application) printLogo() { +func (a *Application) printLogo(allocID string) { for i := range logoStr { log.Infof("%s", logoStr[i]) } log.Infof("") - log.Infof("jackal %v\n", version.ApplicationVersion) + log.Infof("jackal %v - allocation_id: %s\n", version.ApplicationVersion, allocID) +} + +func (a *Application) setRLimit() error { + var rLim syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLim); err != nil { + return err + } + if rLim.Cur < rLim.Max { + switch runtime.GOOS { + case "darwin": + // The max file limit is 10240, even though + // the max returned by Getrlimit is 1<<63-1. + // This is OPEN_MAX in sys/syslimits.h. + rLim.Cur = darwinOpenMax + default: + rLim.Cur = rLim.Max + } + return syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLim) + } + return nil } func (a *Application) initDebugServer(port int) error { @@ -267,7 +292,7 @@ func (a *Application) initDebugServer(port int) error { if err != nil { return err } - go a.debugSrv.Serve(ln) + go func() { _ = a.debugSrv.Serve(ln) }() log.Infof("debug server listening at %d...", port) return nil } @@ -293,22 +318,34 @@ func (a *Application) gracefullyShutdown() error { func (a *Application) shutdown(ctx context.Context) <-chan bool { c := make(chan bool, 1) go func() { - if a.debugSrv != nil { - a.debugSrv.Shutdown(ctx) - } - a.c2s.Shutdown(ctx) - if a.s2s != nil { - a.s2s.Shutdown(ctx) + if err := a.doShutdown(ctx); err != nil { + log.Warnf("failed to shutdown: %s", err) } - if a.cluster != nil { - a.cluster.Shutdown() - } - a.comps.Shutdown(ctx) - a.mods.Shutdown(ctx) - - storage.Unset() - log.Unset() c <- true }() return c } + +func (a *Application) doShutdown(ctx context.Context) error { + if a.debugSrv != nil { + if err := a.debugSrv.Shutdown(ctx); err != nil { + return err + } + } + a.c2s.Shutdown(ctx) + + if err := a.comps.Shutdown(ctx); err != nil { + return err + } + if err := a.mods.Shutdown(ctx); err != nil { + return err + } + + if outProvider := a.s2sOutProvider; outProvider != nil { + if err := outProvider.Shutdown(ctx); err != nil { + return err + } + } + log.Unset() + return nil +} diff --git a/app/config.go b/app/config.go index d8a854495..7901a1e94 100644 --- a/app/config.go +++ b/app/config.go @@ -9,15 +9,13 @@ import ( "bytes" "io/ioutil" - "github.com/ortuman/jackal/cluster" - "github.com/ortuman/jackal/c2s" "github.com/ortuman/jackal/component" "github.com/ortuman/jackal/module" - "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/router/host" "github.com/ortuman/jackal/s2s" "github.com/ortuman/jackal/storage" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) // debugConfig represents debug server configuration. @@ -36,16 +34,14 @@ type Config struct { Debug debugConfig `yaml:"debug"` Logger loggerConfig `yaml:"logger"` Storage storage.Config `yaml:"storage"` - Cluster *cluster.Config `yaml:"cluster"` - Router router.Config `yaml:"router"` + Hosts []host.Config `yaml:"hosts"` Modules module.Config `yaml:"modules"` Components component.Config `yaml:"components"` C2S []c2s.Config `yaml:"c2s"` S2S *s2s.Config `yaml:"s2s"` } -// FromFile loads default global configuration from -// a specified file. +// FromFile loads default global configuration from a specified file. func (cfg *Config) FromFile(configFile string) error { b, err := ioutil.ReadFile(configFile) if err != nil { @@ -54,8 +50,7 @@ func (cfg *Config) FromFile(configFile string) error { return yaml.Unmarshal(b, cfg) } -// FromBuffer loads default global configuration from -// a specified byte buffer. +// FromBuffer loads default global configuration from a specified byte buffer. func (cfg *Config) FromBuffer(buf *bytes.Buffer) error { return yaml.Unmarshal(buf.Bytes(), cfg) } diff --git a/auth/auth.go b/auth/auth.go index e2e175a7a..7d784931d 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -5,7 +5,11 @@ package auth -import "github.com/ortuman/jackal/xmpp" +import ( + "context" + + "github.com/ortuman/jackal/xmpp" +) const saslNamespace = "urn:ietf:params:xml:ns:xmpp-sasl" @@ -27,7 +31,7 @@ type Authenticator interface { UsesChannelBinding() bool // ProcessElement process an incoming authenticator element. - ProcessElement(xmpp.XElement) error + ProcessElement(context.Context, xmpp.XElement) error // Reset resets authenticator internal state. Reset() diff --git a/auth/auth_test.go b/auth/auth_test.go index 178356c03..f6e9932d8 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -6,35 +6,30 @@ package auth import ( + "context" "testing" + "github.com/google/uuid" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp/jid" - "github.com/pborman/uuid" "github.com/stretchr/testify/require" ) -func authTestSetup(user *model.User) (*stream.MockC2S, *memstorage.Storage) { - s := memstorage.New() - storage.Set(s) +func authTestSetup(user *model.User) (*stream.MockC2S, *memorystorage.User) { + s := memorystorage.NewUser() - storage.InsertOrUpdateUser(user) + _ = s.UpsertUser(context.Background(), user) j, _ := jid.New("mariana", "localhost", "res", true) - testStm := stream.NewMockC2S(uuid.New(), j) + testStm := stream.NewMockC2S(uuid.New().String(), j) testStm.SetJID(j) return testStm, s } -func authTestTeardown() { - storage.Unset() -} - func TestAuthError(t *testing.T) { require.Equal(t, "incorrect-encoding", ErrSASLIncorrectEncoding.(*SASLError).Error()) require.Equal(t, "malformed-request", ErrSASLMalformedRequest.(*SASLError).Error()) diff --git a/auth/digest_md5.go b/auth/digest_md5.go deleted file mode 100644 index 04cace488..000000000 --- a/auth/digest_md5.go +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package auth - -import ( - "bytes" - "crypto/md5" - "encoding/base64" - "encoding/hex" - "fmt" - "strings" - - "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/stream" - "github.com/ortuman/jackal/util" - "github.com/ortuman/jackal/xmpp" -) - -type digestMD5State int - -const ( - startDigestMD5State digestMD5State = iota - challengedDigestMD5State - authenticatedDigestMD5State -) - -type digestMD5Parameters struct { - username string - realm string - nonce string - cnonce string - nc string - qop string - servType string - digestURI string - response string - charset string - authID string -} - -func (r *digestMD5Parameters) setParameter(p string) { - key, val := util.SplitKeyAndValue(p, '=') - - // strip value double quotes - val = strings.TrimPrefix(val, `"`) - val = strings.TrimSuffix(val, `"`) - - switch key { - case "username": - r.username = val - case "realm": - r.realm = val - case "nonce": - r.nonce = val - case "cnonce": - r.cnonce = val - case "nc": - r.nc = val - case "qop": - r.qop = val - case "serv-type": - r.servType = val - case "digest-uri": - r.digestURI = val - case "response": - r.response = val - case "charset": - r.charset = val - case "authzid": - r.authID = val - } -} - -// DigestMD5 represents a DIGEST-MD5 authenticator. -type DigestMD5 struct { - stm stream.C2S - state digestMD5State - username string - authenticated bool -} - -// NewDigestMD5 returns a new digest-md5 authenticator instance. -func NewDigestMD5(stm stream.C2S) *DigestMD5 { - return &DigestMD5{ - stm: stm, - state: startDigestMD5State, - } -} - -// Mechanism returns authenticator mechanism name. -func (d *DigestMD5) Mechanism() string { - return "DIGEST-MD5" -} - -// Username returns authenticated username in case -// authentication process has been completed. -func (d *DigestMD5) Username() string { - return d.username -} - -// Authenticated returns whether or not user has been authenticated. -func (d *DigestMD5) Authenticated() bool { - return d.authenticated -} - -// UsesChannelBinding returns whether or not digest-md5 authenticator -// requires channel binding bytes. -func (d *DigestMD5) UsesChannelBinding() bool { - return false -} - -// ProcessElement process an incoming authenticator element. -func (d *DigestMD5) ProcessElement(elem xmpp.XElement) error { - if d.Authenticated() { - return nil - } - switch elem.Name() { - case "auth": - switch d.state { - case startDigestMD5State: - return d.handleStart(elem) - } - case "response": - switch d.state { - case challengedDigestMD5State: - return d.handleChallenged(elem) - case authenticatedDigestMD5State: - return d.handleAuthenticated(elem) - } - } - return ErrSASLNotAuthorized -} - -// Reset resets digest-md5 authenticator internal state. -func (d *DigestMD5) Reset() { - d.state = startDigestMD5State - d.username = "" - d.authenticated = false -} - -func (d *DigestMD5) handleStart(elem xmpp.XElement) error { - domain := d.stm.Domain() - nonce := base64.StdEncoding.EncodeToString(util.RandomBytes(32)) - chnge := fmt.Sprintf(`realm="%s",nonce="%s",qop="auth",charset=utf-8,algorithm=md5-sess`, domain, nonce) - - respElem := xmpp.NewElementNamespace("challenge", saslNamespace) - respElem.SetText(base64.StdEncoding.EncodeToString([]byte(chnge))) - d.stm.SendElement(respElem) - - d.state = challengedDigestMD5State - return nil -} - -func (d *DigestMD5) handleChallenged(elem xmpp.XElement) error { - if len(elem.Text()) == 0 { - return ErrSASLMalformedRequest - } - b, err := base64.StdEncoding.DecodeString(elem.Text()) - if err != nil { - return ErrSASLIncorrectEncoding - } - params := d.parseParameters(string(b)) - - // validate realm - if params.realm != d.stm.Domain() { - return ErrSASLNotAuthorized - } - // validate nc - if params.nc != "00000001" { - return ErrSASLNotAuthorized - } - // validate qop - if params.qop != "auth" { - return ErrSASLNotAuthorized - } - // validate serv-type - if len(params.servType) > 0 && params.servType != "xmpp" { - return ErrSASLNotAuthorized - } - // validate digest-uri - if !strings.HasPrefix(params.digestURI, "xmpp/") || params.digestURI[5:] != d.stm.Domain() { - return ErrSASLNotAuthorized - } - // validate user - user, err := storage.FetchUser(params.username) - if err != nil { - return err - } - if user == nil { - return ErrSASLNotAuthorized - } - // validate response - clientResp := d.computeResponse(params, user, true) - if clientResp != params.response { - return ErrSASLNotAuthorized - } - - // authenticated... compute and send server response - serverResp := d.computeResponse(params, user, false) - respAuth := fmt.Sprintf("rspauth=%s", serverResp) - - respElem := xmpp.NewElementNamespace("challenge", saslNamespace) - respElem.SetText(base64.StdEncoding.EncodeToString([]byte(respAuth))) - d.stm.SendElement(respElem) - - d.username = user.Username - d.state = authenticatedDigestMD5State - return nil -} - -func (d *DigestMD5) handleAuthenticated(elem xmpp.XElement) error { - d.authenticated = true - d.stm.SendElement(xmpp.NewElementNamespace("success", saslNamespace)) - return nil -} - -func (d *DigestMD5) parseParameters(str string) *digestMD5Parameters { - params := &digestMD5Parameters{} - s := strings.Split(str, ",") - for i := 0; i < len(s); i++ { - params.setParameter(s[i]) - } - return params -} - -func (d *DigestMD5) computeResponse(params *digestMD5Parameters, user *model.User, asClient bool) string { - x := params.username + ":" + params.realm + ":" + user.Password - y := d.md5Hash([]byte(x)) - - a1 := bytes.NewBuffer(y) - a1.WriteString(":" + params.nonce + ":" + params.cnonce) - if len(params.authID) > 0 { - a1.WriteString(":" + params.authID) - } - - var c string - if asClient { - c = "AUTHENTICATE" - } else { - c = "" - } - a2 := bytes.NewBuffer([]byte(c)) - a2.WriteString(":" + params.digestURI) - - ha1 := hex.EncodeToString(d.md5Hash(a1.Bytes())) - ha2 := hex.EncodeToString(d.md5Hash(a2.Bytes())) - - kd := ha1 - kd += ":" + params.nonce - kd += ":" + params.nc - kd += ":" + params.cnonce - kd += ":" + params.qop - kd += ":" + ha2 - return hex.EncodeToString(d.md5Hash([]byte(kd))) -} - -func (d *DigestMD5) md5Hash(b []byte) []byte { - hasher := md5.New() - hasher.Write(b) - return hasher.Sum(nil) -} diff --git a/auth/digest_md5_test.go b/auth/digest_md5_test.go deleted file mode 100644 index df32bb70b..000000000 --- a/auth/digest_md5_test.go +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package auth - -import ( - "encoding/base64" - "encoding/hex" - "fmt" - "testing" - - "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/storage/memstorage" - "github.com/ortuman/jackal/stream" - "github.com/ortuman/jackal/util" - "github.com/ortuman/jackal/xmpp" - "github.com/stretchr/testify/require" -) - -type digestMD5AuthTestHelper struct { - t *testing.T - testStrm stream.C2S - authr *DigestMD5 -} - -func (h *digestMD5AuthTestHelper) clientParamsFromChallenge(challenge string) *digestMD5Parameters { - b, err := base64.StdEncoding.DecodeString(challenge) - require.Nil(h.t, err) - srvParams := h.authr.parseParameters(string(b)) - clParams := *srvParams - clParams.setParameter("cnonce=" + hex.EncodeToString(util.RandomBytes(16))) - clParams.setParameter("username=" + h.testStrm.Username()) - clParams.setParameter("realm=" + h.testStrm.Domain()) - clParams.setParameter("nc=00000001") - clParams.setParameter("qop=auth") - clParams.setParameter("digest-uri=" + fmt.Sprintf("xmpp/%s", h.testStrm.Domain())) - clParams.setParameter("charset=utf-8") - clParams.setParameter("authzid=test") - return &clParams -} - -func (h *digestMD5AuthTestHelper) sendClientParamsResponse(params *digestMD5Parameters) error { - response := xmpp.NewElementNamespace("response", "urn:ietf:params:xml:ns:xmpp-sasl") - response.SetText(h.serializeParams(params)) - return h.authr.ProcessElement(response) -} - -func (h *digestMD5AuthTestHelper) serializeParams(params *digestMD5Parameters) string { - fmtStr := `username="%s",realm="%s",nonce="%s",cnonce="%s",nc=%s,qop=%s,digest-uri="%s",response=%s,charset=%s` - str := fmt.Sprintf(fmtStr, params.username, params.realm, params.nonce, params.cnonce, params.nc, params.qop, - params.digestURI, params.response, params.charset) - if len(params.servType) > 0 { - str += ",serv-type=" + params.servType - } - if len(params.authID) > 0 { - str += ",authzid=" + params.authID - } - return base64.StdEncoding.EncodeToString([]byte(str)) -} - -func TestDigesMD5Authentication(t *testing.T) { - user := &model.User{Username: "mariana", Password: "1234"} - testStm, s := authTestSetup(user) - defer authTestTeardown() - - authr := NewDigestMD5(testStm) - require.Equal(t, authr.Mechanism(), "DIGEST-MD5") - require.False(t, authr.UsesChannelBinding()) - - // test garbage input... - require.Equal(t, authr.ProcessElement(xmpp.NewElementName("garbage")), ErrSASLNotAuthorized) - - helper := digestMD5AuthTestHelper{t: t, testStrm: testStm, authr: authr} - - auth := xmpp.NewElementNamespace("auth", "urn:ietf:params:xml:ns:xmpp-sasl") - auth.SetAttribute("mechanism", "DIGEST-MD5") - authr.ProcessElement(auth) - - challenge := testStm.ReceiveElement() - require.Equal(t, challenge.Name(), "challenge") - clParams := helper.clientParamsFromChallenge(challenge.Text()) - clientResp := authr.computeResponse(clParams, user, true) - clParams.setParameter("response=" + clientResp) - clParams.response = clientResp - - // empty payload - response := xmpp.NewElementNamespace("response", "urn:ietf:params:xml:ns:xmpp-sasl") - response.SetText("") - require.Equal(t, ErrSASLMalformedRequest, authr.ProcessElement(response)) - - // incorrect payload encoding - response.SetText("bad_payload") - require.Equal(t, ErrSASLIncorrectEncoding, authr.ProcessElement(response)) - - // invalid username... - cl0 := *clParams - cl0.setParameter("username=mariana-inv") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl0)) - - // invalid realm... - cl1 := *clParams - cl1.setParameter("realm=localhost-inv") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl1)) - - // invalid nc... - cl2 := *clParams - cl2.setParameter("nc=00000001-inv") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl2)) - - // invalid nc... - cl3 := *clParams - cl3.setParameter("qop=auth-inv") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl3)) - - // invalid serv-type... - cl4 := *clParams - cl4.setParameter("serv-type=http") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl4)) - - // invalid digest-uri... - cl5 := *clParams - cl5.setParameter("digest-uri=http/localhost") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl5)) - - cl6 := *clParams - cl6.setParameter("digest-uri=xmpp/localhost-inv") - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl6)) - - // invalid password... - cl7 := *clParams - user2 := &model.User{Username: "mariana", Password: "bad_password"} - badClientResp := authr.computeResponse(&cl7, user2, true) - cl7.setParameter("response=" + badClientResp) - require.Equal(t, ErrSASLNotAuthorized, helper.sendClientParamsResponse(&cl7)) - - // storage error... - s.EnableMockedError() - require.Equal(t, memstorage.ErrMockedError, helper.sendClientParamsResponse(clParams)) - s.DisableMockedError() - - // successful authentication... - require.Nil(t, helper.sendClientParamsResponse(clParams)) - - challenge = testStm.ReceiveElement() - - serverResp := authr.computeResponse(clParams, user, false) - require.Equal(t, base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("rspauth=%s", serverResp))), challenge.Text()) - - response.SetText("") - authr.ProcessElement(response) - - success := testStm.ReceiveElement() - require.Equal(t, "success", success.Name()) - - // successfully authenticated - require.True(t, authr.Authenticated()) - require.Equal(t, "mariana", authr.Username()) - - // already authenticated... - require.Nil(t, authr.ProcessElement(auth)) - - // test reset - authr.Reset() - require.Equal(t, authr.state, startDigestMD5State) - require.False(t, authr.Authenticated()) - require.Equal(t, "", authr.Username()) -} diff --git a/auth/plain.go b/auth/plain.go index 2399be094..a9fa24c6b 100644 --- a/auth/plain.go +++ b/auth/plain.go @@ -7,9 +7,10 @@ package auth import ( "bytes" + "context" "encoding/base64" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" ) @@ -17,13 +18,14 @@ import ( // Plain represents a PLAIN authenticator. type Plain struct { stm stream.C2S + userRep repository.User username string authenticated bool } // NewPlain returns a new plain authenticator instance. -func NewPlain(stm stream.C2S) *Plain { - return &Plain{stm: stm} +func NewPlain(stm stream.C2S, userRep repository.User) *Plain { + return &Plain{stm: stm, userRep: userRep} } // Mechanism returns authenticator mechanism name. @@ -49,7 +51,7 @@ func (p *Plain) UsesChannelBinding() bool { } // ProcessElement process an incoming authenticator element. -func (p *Plain) ProcessElement(elem xmpp.XElement) error { +func (p *Plain) ProcessElement(ctx context.Context, elem xmpp.XElement) error { if p.authenticated { return nil } @@ -68,7 +70,7 @@ func (p *Plain) ProcessElement(elem xmpp.XElement) error { password := string(s[2]) // validate user and password - user, err := storage.FetchUser(username) + user, err := p.userRep.FetchUser(ctx, username) if err != nil { return err } @@ -78,7 +80,7 @@ func (p *Plain) ProcessElement(elem xmpp.XElement) error { p.username = username p.authenticated = true - p.stm.SendElement(xmpp.NewElementNamespace("success", saslNamespace)) + p.stm.SendElement(ctx, xmpp.NewElementNamespace("success", saslNamespace)) return nil } diff --git a/auth/plain_test.go b/auth/plain_test.go index 546bfe81a..c2db34f2a 100644 --- a/auth/plain_test.go +++ b/auth/plain_test.go @@ -7,11 +7,12 @@ package auth import ( "bytes" + "context" "encoding/base64" "testing" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" ) @@ -20,15 +21,14 @@ func TestAuthPlainAuthentication(t *testing.T) { var err error testStm, s := authTestSetup(&model.User{Username: "mariana", Password: "1234"}) - defer authTestTeardown() - authr := NewPlain(testStm) + authr := NewPlain(testStm, s) require.Equal(t, authr.Mechanism(), "PLAIN") require.False(t, authr.UsesChannelBinding()) elem := xmpp.NewElementNamespace("auth", "urn:ietf:params:xml:ns:xmpp-sasl") elem.SetAttribute("mechanism", "PLAIN") - authr.ProcessElement(elem) + _ = authr.ProcessElement(context.Background(), elem) buf := new(bytes.Buffer) buf.WriteByte(0) @@ -38,30 +38,30 @@ func TestAuthPlainAuthentication(t *testing.T) { elem.SetText(base64.StdEncoding.EncodeToString(buf.Bytes())) // storage error... - s.EnableMockedError() - require.Equal(t, authr.ProcessElement(elem), memstorage.ErrMockedError) - s.DisableMockedError() + memorystorage.EnableMockedError() + require.Equal(t, authr.ProcessElement(context.Background(), elem), memorystorage.ErrMocked) + memorystorage.DisableMockedError() // valid credentials... - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Nil(t, err) require.Equal(t, "mariana", authr.Username()) require.True(t, authr.Authenticated()) // already authenticated... - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Nil(t, err) // malformed request authr.Reset() elem.SetText("") - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Equal(t, ErrSASLMalformedRequest, err) // invalid payload authr.Reset() elem.SetText("bad formed base64") - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Equal(t, ErrSASLIncorrectEncoding, err) // invalid payload @@ -74,7 +74,7 @@ func TestAuthPlainAuthentication(t *testing.T) { elem.SetText(base64.StdEncoding.EncodeToString(buf.Bytes())) authr.Reset() - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Equal(t, ErrSASLIncorrectEncoding, err) // invalid user @@ -86,7 +86,7 @@ func TestAuthPlainAuthentication(t *testing.T) { elem.SetText(base64.StdEncoding.EncodeToString(buf.Bytes())) authr.Reset() - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Equal(t, ErrSASLNotAuthorized, err) // incorrect password @@ -98,6 +98,6 @@ func TestAuthPlainAuthentication(t *testing.T) { elem.SetText(base64.StdEncoding.EncodeToString(buf.Bytes())) authr.Reset() - err = authr.ProcessElement(elem) + err = authr.ProcessElement(context.Background(), elem) require.Equal(t, ErrSASLNotAuthorized, err) } diff --git a/auth/scram.go b/auth/scram.go index dba97755b..d813d8f35 100644 --- a/auth/scram.go +++ b/auth/scram.go @@ -7,22 +7,23 @@ package auth import ( "bytes" + "context" "crypto/hmac" + "crypto/rand" "crypto/sha1" "crypto/sha256" - "crypto/sha512" "encoding/base64" "fmt" "hash" "strings" + "github.com/google/uuid" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/transport" - "github.com/ortuman/jackal/util" + utilstring "github.com/ortuman/jackal/util/string" "github.com/ortuman/jackal/xmpp" - "github.com/pborman/uuid" "golang.org/x/crypto/pbkdf2" ) @@ -35,9 +36,6 @@ const ( // ScramSHA256 represents SCRAM-SHA-256 authentication method. ScramSHA256 - - // ScramSHA512 represents SCRAM-SHA-512 authentication method. - ScramSHA512 ) const iterationsCount = 4096 @@ -84,6 +82,7 @@ func (s *scramParameters) String() string { // Scram represents a SCRAM authenticator. type Scram struct { stm stream.C2S + userRep repository.User tr transport.Transport tp ScramType usesCb bool @@ -99,13 +98,14 @@ type Scram struct { } // NewScram returns a new scram authenticator instance. -func NewScram(stm stream.C2S, tr transport.Transport, scramType ScramType, usesChannelBinding bool) *Scram { +func NewScram(stm stream.C2S, tr transport.Transport, scramType ScramType, usesChannelBinding bool, userRep repository.User) *Scram { s := &Scram{ - stm: stm, - tr: tr, - tp: scramType, - usesCb: usesChannelBinding, - state: startScramState, + stm: stm, + userRep: userRep, + tr: tr, + tp: scramType, + usesCb: usesChannelBinding, + state: startScramState, } switch s.tp { case ScramSHA1: @@ -114,9 +114,6 @@ func NewScram(stm stream.C2S, tr transport.Transport, scramType ScramType, usesC case ScramSHA256: s.h = sha256.New s.hKeyLen = sha256.Size - case ScramSHA512: - s.h = sha512.New - s.hKeyLen = sha512.Size } return s } @@ -135,12 +132,6 @@ func (s *Scram) Mechanism() string { return "SCRAM-SHA-256-PLUS" } return "SCRAM-SHA-256" - - case ScramSHA512: - if s.usesCb { - return "SCRAM-SHA-512-PLUS" - } - return "SCRAM-SHA-512" } return "" } @@ -166,18 +157,18 @@ func (s *Scram) UsesChannelBinding() bool { } // ProcessElement process an incoming authenticator element. -func (s *Scram) ProcessElement(elem xmpp.XElement) error { +func (s *Scram) ProcessElement(ctx context.Context, elem xmpp.XElement) error { if s.Authenticated() { return nil } switch elem.Name() { case "auth": if s.state == startScramState { - return s.handleStart(elem) + return s.handleStart(ctx, elem) } case "response": if s.state == challengedScramState { - return s.handleChallenged(elem) + return s.handleChallenged(ctx, elem) } } return ErrSASLNotAuthorized @@ -195,7 +186,7 @@ func (s *Scram) Reset() { s.firstMessage = "" } -func (s *Scram) handleStart(elem xmpp.XElement) error { +func (s *Scram) handleStart(ctx context.Context, elem xmpp.XElement) error { p, err := s.getElementPayload(elem) if err != nil { return err @@ -209,7 +200,7 @@ func (s *Scram) handleStart(elem xmpp.XElement) error { if len(username) == 0 || len(cNonce) == 0 { return ErrSASLMalformedRequest } - user, err := storage.FetchUser(username) + user, err := s.userRep.FetchUser(ctx, username) if err != nil { return err } @@ -218,20 +209,24 @@ func (s *Scram) handleStart(elem xmpp.XElement) error { } s.user = user - s.srvNonce = cNonce + "-" + uuid.New() - s.salt = util.RandomBytes(32) + s.srvNonce = cNonce + "-" + uuid.New().String() + s.salt = make([]byte, 32) + _, err = rand.Read(s.salt) + if err != nil { + return err + } sb64 := base64.StdEncoding.EncodeToString(s.salt) s.firstMessage = fmt.Sprintf("r=%s,s=%s,i=%d", s.srvNonce, sb64, iterationsCount) respElem := xmpp.NewElementNamespace("challenge", saslNamespace) respElem.SetText(base64.StdEncoding.EncodeToString([]byte(s.firstMessage))) - s.stm.SendElement(respElem) + s.stm.SendElement(ctx, respElem) s.state = challengedScramState return nil } -func (s *Scram) handleChallenged(elem xmpp.XElement) error { +func (s *Scram) handleChallenged(ctx context.Context, elem xmpp.XElement) error { p, err := s.getElementPayload(elem) if err != nil { return err @@ -261,7 +256,7 @@ func (s *Scram) handleChallenged(elem xmpp.XElement) error { respElem := xmpp.NewElementNamespace("success", saslNamespace) respElem.SetText(base64.StdEncoding.EncodeToString([]byte(v))) - s.stm.SendElement(respElem) + s.stm.SendElement(ctx, respElem) s.authenticated = true return nil @@ -287,12 +282,15 @@ func (s *Scram) parseParameters(str string) error { } gs2BindFlag := sp[0] + // https://tools.ietf.org/html/rfc5801#section-5 switch gs2BindFlag { - case "y": + case "p": + // Channel binding is supported and required. if !s.usesCb { return ErrSASLNotAuthorized } - case "n": + case "n", "y": + // Channel binding is not supported, or is supported but is not required. break default: if !strings.HasPrefix(gs2BindFlag, "p=") { @@ -307,14 +305,14 @@ func (s *Scram) parseParameters(str string) error { p.gs2Header = gs2BindFlag + "," + authzID + "," if len(authzID) > 0 { - key, val := util.SplitKeyAndValue(authzID, '=') + key, val := utilstring.SplitKeyAndValue(authzID, '=') if len(key) == 0 || key != "a" { return ErrSASLMalformedRequest } p.authzID = val } for i := 2; i < len(sp); i++ { - key, val := util.SplitKeyAndValue(sp[i], '=') + key, val := utilstring.SplitKeyAndValue(sp[i], '=') p.params = append(p.params, scramParameter{key, val}) } s.params = p diff --git a/auth/scram_test.go b/auth/scram_test.go index 2803fa248..4b5ef1252 100644 --- a/auth/scram_test.go +++ b/auth/scram_test.go @@ -7,28 +7,39 @@ package auth import ( "bytes" + "context" "crypto/hmac" "crypto/sha1" "crypto/sha256" - "crypto/sha512" "crypto/tls" "crypto/x509" "encoding/base64" "fmt" "hash" + "math/rand" "strconv" "strings" "testing" + "time" "github.com/ortuman/jackal/model" "github.com/ortuman/jackal/transport" "github.com/ortuman/jackal/transport/compress" - "github.com/ortuman/jackal/util" + utilstring "github.com/ortuman/jackal/util/string" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" "golang.org/x/crypto/pbkdf2" ) +func randomBytes(l int) []byte { + b := make([]byte, l) + _, err := rand.Read(b) + if err != nil { + panic(err) + } + return b +} + type fakeTransport struct { cbBytes []byte } @@ -38,6 +49,7 @@ func (ft *fakeTransport) Write(p []byte) (n int, err error) { return 0, ni func (ft *fakeTransport) Close() error { return nil } func (ft *fakeTransport) Type() transport.Type { return transport.Socket } func (ft *fakeTransport) Flush() error { return nil } +func (ft *fakeTransport) SetWriteDeadline(_ time.Time) error { return nil } func (ft *fakeTransport) WriteString(s string) (n int, err error) { return 0, nil } func (ft *fakeTransport) StartTLS(*tls.Config, bool) { return } func (ft *fakeTransport) EnableCompression(compress.Level) { return } @@ -87,22 +99,12 @@ var tt = []scramAuthTestCase{ r: "6d805d99-6dc3-4e5a-9a68-653856fc5129", password: "1234", }, - { - // SCRAM-SHA-512 - id: 3, - scramType: ScramSHA512, - usesCb: false, - gs2BindFlag: "n", - n: "ortuman", - r: "6d805d99-6dc3-4e5a-9a68-653856fc5129", - password: "1234", - }, { // SCRAM-SHA-1-PLUS - id: 4, + id: 3, scramType: ScramSHA1, usesCb: true, - cbBytes: util.RandomBytes(23), + cbBytes: randomBytes(23), gs2BindFlag: "p=tls-unique", authID: "a=jackal.im", n: "ortuman", @@ -111,22 +113,10 @@ var tt = []scramAuthTestCase{ }, { // SCRAM-SHA-256-PLUS - id: 5, + id: 4, scramType: ScramSHA256, usesCb: true, - cbBytes: util.RandomBytes(32), - gs2BindFlag: "p=tls-unique", - authID: "a=jackal.im", - n: "ortuman", - r: "d712875c-bd3b-4b41-801d-eb9c541d9884", - password: "1234", - }, - { - // SCRAM-SHA-512-PLUS - id: 6, - scramType: ScramSHA512, - usesCb: true, - cbBytes: util.RandomBytes(32), + cbBytes: randomBytes(32), gs2BindFlag: "p=tls-unique", authID: "a=jackal.im", n: "ortuman", @@ -137,7 +127,7 @@ var tt = []scramAuthTestCase{ // Fail cases { // invalid user - id: 7, + id: 5, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "n", @@ -148,7 +138,7 @@ var tt = []scramAuthTestCase{ }, { // invalid password - id: 8, + id: 6, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "n", @@ -159,7 +149,7 @@ var tt = []scramAuthTestCase{ }, { // not authorized gs2BindFlag - id: 9, + id: 7, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "y", @@ -170,7 +160,7 @@ var tt = []scramAuthTestCase{ }, { // invalid authID - id: 10, + id: 8, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "n", @@ -182,7 +172,7 @@ var tt = []scramAuthTestCase{ }, { // not matching gs2BindFlag - id: 11, + id: 9, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "p=tls-unique", @@ -194,7 +184,7 @@ var tt = []scramAuthTestCase{ }, { // not matching gs2BindFlag - id: 12, + id: 10, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "q=tls-unique", @@ -206,7 +196,7 @@ var tt = []scramAuthTestCase{ }, { // empty username - id: 13, + id: 11, scramType: ScramSHA1, usesCb: false, gs2BindFlag: "n", @@ -220,54 +210,44 @@ var tt = []scramAuthTestCase{ func TestScramMechanisms(t *testing.T) { testTr := &fakeTransport{} - testStm, _ := authTestSetup(&model.User{Username: "ortuman", Password: "1234"}) - defer authTestTeardown() + testStm, s := authTestSetup(&model.User{Username: "ortuman", Password: "1234"}) - authr := NewScram(testStm, testTr, ScramSHA1, false) + authr := NewScram(testStm, testTr, ScramSHA1, false, s) require.Equal(t, authr.Mechanism(), "SCRAM-SHA-1") require.False(t, authr.UsesChannelBinding()) - authr2 := NewScram(testStm, testTr, ScramSHA1, true) + authr2 := NewScram(testStm, testTr, ScramSHA1, true, s) require.Equal(t, authr2.Mechanism(), "SCRAM-SHA-1-PLUS") require.True(t, authr2.UsesChannelBinding()) - authr3 := NewScram(testStm, testTr, ScramSHA256, false) + authr3 := NewScram(testStm, testTr, ScramSHA256, false, s) require.Equal(t, authr3.Mechanism(), "SCRAM-SHA-256") require.False(t, authr3.UsesChannelBinding()) - authr4 := NewScram(testStm, testTr, ScramSHA256, true) + authr4 := NewScram(testStm, testTr, ScramSHA256, true, s) require.Equal(t, authr4.Mechanism(), "SCRAM-SHA-256-PLUS") require.True(t, authr4.UsesChannelBinding()) - authr5 := NewScram(testStm, testTr, ScramSHA512, false) - require.Equal(t, authr5.Mechanism(), "SCRAM-SHA-512") - require.False(t, authr5.UsesChannelBinding()) - - authr6 := NewScram(testStm, testTr, ScramSHA512, true) - require.Equal(t, authr6.Mechanism(), "SCRAM-SHA-512-PLUS") - require.True(t, authr6.UsesChannelBinding()) - - authr7 := NewScram(testStm, testTr, ScramType(99), true) - require.Equal(t, authr7.Mechanism(), "") + authr5 := NewScram(testStm, testTr, ScramType(99), true, s) + require.Equal(t, authr5.Mechanism(), "") } func TestScramBadPayload(t *testing.T) { testTr := &fakeTransport{} - testStm, _ := authTestSetup(&model.User{Username: "ortuman", Password: "1234"}) - defer authTestTeardown() + testStm, s := authTestSetup(&model.User{Username: "ortuman", Password: "1234"}) - authr := NewScram(testStm, testTr, ScramSHA1, false) + authr := NewScram(testStm, testTr, ScramSHA1, false, s) auth := xmpp.NewElementNamespace("auth", "urn:ietf:params:xml:ns:xmpp-sasl") auth.SetAttribute("mechanism", authr.Mechanism()) // empty auth payload - require.Equal(t, ErrSASLIncorrectEncoding, authr.ProcessElement(auth)) + require.Equal(t, ErrSASLIncorrectEncoding, authr.ProcessElement(context.Background(), auth)) // incorrect auth payload encoding authr.Reset() auth.SetText(".") - require.Equal(t, ErrSASLIncorrectEncoding, authr.ProcessElement(auth)) + require.Equal(t, ErrSASLIncorrectEncoding, authr.ProcessElement(context.Background(), auth)) } func TestScramTestCases(t *testing.T) { @@ -285,10 +265,9 @@ func processScramTestCase(t *testing.T, tc *scramAuthTestCase) error { if tc.usesCb { tr.cbBytes = tc.cbBytes } - testStm, _ := authTestSetup(&model.User{Username: "ortuman", Password: "1234"}) - defer authTestTeardown() + testStm, s := authTestSetup(&model.User{Username: "ortuman", Password: "1234"}) - authr := NewScram(testStm, tr, tc.scramType, tc.usesCb) + authr := NewScram(testStm, tr, tc.scramType, tc.usesCb, s) auth := xmpp.NewElementNamespace("auth", saslNamespace) auth.SetAttribute("mechanism", authr.Mechanism()) @@ -298,7 +277,7 @@ func processScramTestCase(t *testing.T, tc *scramAuthTestCase) error { authPayload := gs2Header + clientInitialMessage auth.SetText(base64.StdEncoding.EncodeToString([]byte(authPayload))) - err := authr.ProcessElement(auth) + err := authr.ProcessElement(context.Background(), auth) if err != nil { return err } @@ -329,7 +308,7 @@ func processScramTestCase(t *testing.T, tc *scramAuthTestCase) error { response := xmpp.NewElementNamespace("response", saslNamespace) response.SetText(base64.StdEncoding.EncodeToString([]byte(res.clientFinalMessage))) - err = authr.ProcessElement(response) + err = authr.ProcessElement(context.Background(), response) if err != nil { return err } @@ -344,7 +323,7 @@ func processScramTestCase(t *testing.T, tc *scramAuthTestCase) error { require.True(t, authr.Authenticated()) require.Equal(t, tc.n, authr.Username()) - require.Nil(t, authr.ProcessElement(auth)) // test already authenticated... + require.Nil(t, authr.ProcessElement(context.Background(), auth)) // test already authenticated... return nil } @@ -378,7 +357,7 @@ func parseScramResponse(b64 string) (map[string]string, error) { ret := map[string]string{} s1 := strings.Split(string(s), ",") for _, s0 := range s1 { - k, v := util.SplitKeyAndValue(s0, '=') + k, v := utilstring.SplitKeyAndValue(s0, '=') ret[k] = v } return ret, nil @@ -390,8 +369,6 @@ func testScramAuthPbkdf2(b []byte, salt []byte, scramType ScramType, iterationCo return pbkdf2.Key(b, salt, iterationCount, sha1.Size, sha1.New) case ScramSHA256: return pbkdf2.Key(b, salt, iterationCount, sha256.Size, sha256.New) - case ScramSHA512: - return pbkdf2.Key(b, salt, iterationCount, sha512.Size, sha512.New) } return nil } @@ -403,8 +380,6 @@ func testScramAuthHmac(b []byte, key []byte, scramType ScramType) []byte { h = sha1.New case ScramSHA256: h = sha256.New - case ScramSHA512: - h = sha512.New } m := hmac.New(h, key) m.Write(b) @@ -418,8 +393,6 @@ func testScramAuthHash(b []byte, scramType ScramType) []byte { h = sha1.New() case ScramSHA256: h = sha256.New() - case ScramSHA512: - h = sha512.New() } h.Write(b) return h.Sum(nil) diff --git a/c2s/c2s.go b/c2s/c2s.go index 744586688..ea78be5e5 100644 --- a/c2s/c2s.go +++ b/c2s/c2s.go @@ -14,6 +14,7 @@ import ( "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" "github.com/pkg/errors" ) @@ -32,9 +33,7 @@ type c2sServer interface { shutdown(ctx context.Context) error } -var createC2SServer = func(config *Config, mods *module.Modules, comps *component.Components, router *router.Router) c2sServer { - return &server{cfg: config, mods: mods, comps: comps, router: router} -} +var createC2SServer = newC2SServer // C2S represents a client-to-server connection manager. type C2S struct { @@ -44,13 +43,13 @@ type C2S struct { } // New returns a new instance of a c2s connection manager. -func New(configs []Config, mods *module.Modules, comps *component.Components, router *router.Router) (*C2S, error) { +func New(configs []Config, mods *module.Modules, comps *component.Components, router router.Router, userRep repository.User, blockListRep repository.BlockList) (*C2S, error) { if len(configs) == 0 { return nil, errors.New("at least one c2s configuration is required") } c := &C2S{servers: make(map[string]c2sServer)} for _, config := range configs { - srv := createC2SServer(&config, mods, comps, router) + srv := createC2SServer(&config, mods, comps, router, userRep, blockListRep) c.servers[config.ID] = srv } return c, nil diff --git a/c2s/c2s_test.go b/c2s/c2s_test.go index 7c0159972..1e211082f 100644 --- a/c2s/c2s_test.go +++ b/c2s/c2s_test.go @@ -15,11 +15,13 @@ import ( "testing" "time" + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/component" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + "github.com/ortuman/jackal/router/host" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" ) @@ -47,8 +49,8 @@ func (frw *fakeSockReaderWriter) Read(b []byte) (n int, err error) { } func (frw *fakeSockReaderWriter) Close() error { - frw.w.Close() - frw.r.Close() + _ = frw.w.Close() + _ = frw.r.Close() return nil } @@ -92,8 +94,8 @@ func (c *fakeSocketConn) Write(b []byte) (n int, err error) { func (c *fakeSocketConn) Close() error { if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - c.wr.Close() - c.rd.Close() + _ = c.wr.Close() + _ = c.rd.Close() close(c.closeCh) return nil } @@ -135,7 +137,7 @@ func (c *fakeSocketConn) loop() { for { select { case b := <-c.wrCh: - c.wr.Write(b) + _, _ = c.wr.Write(b) case <-c.closeCh: return } @@ -152,15 +154,17 @@ var ( func (a fakeAddr) Network() string { return "net" } func (a fakeAddr) String() string { return "str" } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, repository.User, repository.BlockList) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + userRep := memorystorage.NewUser() + blockListRep := memorystorage.NewBlockList() + r, _ := router.New( + hosts, + c2srouter.New(userRep, blockListRep), + nil, + ) + return r, userRep, blockListRep } type fakeC2SServer struct { @@ -185,7 +189,7 @@ func (s *fakeC2SServer) shutdown(ctx context.Context) error { } func TestC2S_StartAndShutdown(t *testing.T) { - c2s, fakeSrv := setupTestC2S() + c2s, fakeSrv := setupTestC2S("localhost") c2s.Start() select { @@ -204,11 +208,22 @@ func TestC2S_StartAndShutdown(t *testing.T) { } } -func setupTestC2S() (*C2S, *fakeC2SServer) { +func setupTestC2S(domain string) (*C2S, *fakeC2SServer) { srv := newFakeC2SServer() - createC2SServer = func(_ *Config, _ *module.Modules, _ *component.Components, _ *router.Router) c2sServer { + createC2SServer = func(_ *Config, _ *module.Modules, _ *component.Components, _ router.Router, _ repository.User, _ repository.BlockList) c2sServer { return srv } - c2s, _ := New([]Config{{}}, &module.Modules{}, &component.Components{}, &router.Router{}) + + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + userRep := memorystorage.NewUser() + blockListRep := memorystorage.NewBlockList() + r, _ := router.New( + hosts, + c2srouter.New(userRep, blockListRep), + nil, + ) + + c2s, _ := New([]Config{{}}, &module.Modules{}, &component.Components{}, r, userRep, blockListRep) return c2s, srv } diff --git a/c2s/config.go b/c2s/config.go index 5bdbdbb28..78959d9ab 100644 --- a/c2s/config.go +++ b/c2s/config.go @@ -16,11 +16,12 @@ import ( ) const ( - defaultTransportConnectTimeout = time.Duration(5) * time.Second - defaultTransportMaxStanzaSize = 32768 - defaultTransportPort = 5222 - defaultTransportKeepAlive = time.Duration(120) * time.Second - defaultTransportURLPath = "/xmpp/ws" + defaultConnectTimeout = time.Duration(5) * time.Second + defaultTimeout = time.Duration(20) * time.Second + defaultMaxStanzaSize = 32768 + defaultTransportPort = 5222 + defaultTransportKeepAlive = time.Duration(120) * time.Second + defaultTransportURLPath = "/xmpp/ws" ) // ResourceConflictPolicy represents a resource conflict policy. @@ -72,7 +73,6 @@ type TransportConfig struct { Type transport.Type BindAddress string Port int - KeepAlive time.Duration URLPath string } @@ -95,9 +95,6 @@ func (t *TransportConfig) UnmarshalYAML(unmarshal func(interface{}) error) error case "", "socket": t.Type = transport.Socket - case "websocket": - t.Type = transport.WebSocket - default: return fmt.Errorf("c2s.TransportConfig: unrecognized transport type: %s", p.Type) } @@ -113,10 +110,6 @@ func (t *TransportConfig) UnmarshalYAML(unmarshal func(interface{}) error) error if t.Port == 0 { t.Port = defaultTransportPort } - t.KeepAlive = time.Duration(p.KeepAlive) * time.Second - if t.KeepAlive == 0 { - t.KeepAlive = defaultTransportKeepAlive - } return nil } @@ -130,6 +123,8 @@ type TLSConfig struct { type Config struct { ID string ConnectTimeout time.Duration + Timeout time.Duration + KeepAlive time.Duration MaxStanzaSize int ResourceConflict ResourceConflictPolicy Transport TransportConfig @@ -142,6 +137,8 @@ type configProxy struct { Domain string `yaml:"domain"` TLS TLSConfig `yaml:"tls"` ConnectTimeout int `yaml:"connect_timeout"` + Timeout int `yaml:"timeout"` + KeepAlive int `yaml:"keep_alive"` MaxStanzaSize int `yaml:"max_stanza_size"` ResourceConflict string `yaml:"resource_conflict"` Transport TransportConfig `yaml:"transport"` @@ -158,11 +155,19 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { cfg.ID = p.ID cfg.ConnectTimeout = time.Duration(p.ConnectTimeout) * time.Second if cfg.ConnectTimeout == 0 { - cfg.ConnectTimeout = defaultTransportConnectTimeout + cfg.ConnectTimeout = defaultConnectTimeout + } + cfg.Timeout = time.Duration(p.Timeout) * time.Second + if cfg.Timeout == 0 { + cfg.Timeout = defaultTimeout + } + cfg.KeepAlive = time.Duration(p.KeepAlive) * time.Second + if cfg.KeepAlive == 0 { + cfg.KeepAlive = defaultTransportKeepAlive } cfg.MaxStanzaSize = p.MaxStanzaSize if cfg.MaxStanzaSize == 0 { - cfg.MaxStanzaSize = defaultTransportMaxStanzaSize + cfg.MaxStanzaSize = defaultMaxStanzaSize } // validate resource conflict policy type @@ -193,8 +198,9 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { } type streamConfig struct { - transport transport.Transport connectTimeout time.Duration + timeout time.Duration + keepAlive time.Duration maxStanzaSize int resourceConflict ResourceConflictPolicy sasl []string diff --git a/c2s/config_test.go b/c2s/config_test.go index b764244db..65d92f0f5 100644 --- a/c2s/config_test.go +++ b/c2s/config_test.go @@ -8,7 +8,6 @@ package c2s import ( "os" "testing" - "time" "github.com/ortuman/jackal/transport" "github.com/ortuman/jackal/transport/compress" @@ -46,14 +45,6 @@ func TestTransportConfig(t *testing.T) { require.Equal(t, transport.Socket, s.Type) require.Equal(t, "0.0.0.0", s.BindAddress) require.Equal(t, 5222, s.Port) - require.Equal(t, time.Second*time.Duration(120), s.KeepAlive) - - err = yaml.Unmarshal([]byte("{type: websocket, url_path: /xmpp/ws}"), &s) - require.Nil(t, err) - - require.Equal(t, transport.WebSocket, s.Type) - require.Equal(t, 5222, s.Port) - require.Equal(t, time.Second*time.Duration(120), s.KeepAlive) } func TestConfig(t *testing.T) { diff --git a/c2s/in.go b/c2s/in.go index 08bcd9e60..9446261f1 100644 --- a/c2s/in.go +++ b/c2s/in.go @@ -6,27 +6,27 @@ package c2s import ( + "context" "crypto/tls" "sync" "sync/atomic" "time" - "github.com/ortuman/jackal/runqueue" - + "github.com/google/uuid" "github.com/ortuman/jackal/auth" - "github.com/ortuman/jackal/cluster" "github.com/ortuman/jackal/component" streamerror "github.com/ortuman/jackal/errors" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" "github.com/ortuman/jackal/session" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/transport" "github.com/ortuman/jackal/transport/compress" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" - "github.com/pborman/uuid" ) const ( @@ -40,42 +40,49 @@ const ( type inStream struct { cfg *streamConfig - router *router.Router + router router.Router + userRep repository.User + blockListRep repository.BlockList mods *module.Modules comps *component.Components sess *session.Session + tr transport.Transport + mu sync.RWMutex id string connectTm *time.Timer + readTimeoutTm *time.Timer state uint32 authenticators []auth.Authenticator activeAuth auth.Authenticator runQueue *runqueue.RunQueue - - mu sync.RWMutex - jid *jid.JID - secured bool - compressed bool - authenticated bool - sessStarted bool - presence *xmpp.Presence - - contextMu sync.RWMutex - context map[string]interface{} -} - -func newStream(id string, config *streamConfig, mods *module.Modules, comps *component.Components, router *router.Router) stream.C2S { + jid *jid.JID + secured bool + compressed bool + authenticated bool + sessStarted bool + presence *xmpp.Presence + ctx context.Context + ctxCancelFn context.CancelFunc +} + +func newStream(id string, config *streamConfig, tr transport.Transport, mods *module.Modules, comps *component.Components, router router.Router, userRep repository.User, blockListRep repository.BlockList) stream.C2S { + ctx, ctxCancelFn := context.WithCancel(context.Background()) s := &inStream{ - cfg: config, - router: router, - mods: mods, - comps: comps, - id: id, - context: make(map[string]interface{}), - runQueue: runqueue.New(id), + cfg: config, + tr: tr, + router: router, + userRep: userRep, + blockListRep: blockListRep, + mods: mods, + comps: comps, + id: id, + runQueue: runqueue.New(id), + ctx: ctx, + ctxCancelFn: ctxCancelFn, } // initialize stream context - secured := !(config.transport.Type() == transport.Socket) + secured := !(tr.Type() == transport.Socket) s.setSecured(secured) s.setJID(&jid.JID{}) @@ -98,79 +105,22 @@ func (s *inStream) ID() string { return s.id } -// Context returns a copy of the stream associated context. -func (s *inStream) Context() map[string]interface{} { - m := make(map[string]interface{}) - s.contextMu.RLock() - for k, v := range s.context { - m[k] = v - } - s.contextMu.RUnlock() - return m -} - -// SetString associates a string context value to a key. -func (s *inStream) SetString(key string, value string) { - s.setContextValue(key, value) -} - -// GetString returns the context value associated with the key as a string. -func (s *inStream) GetString(key string) string { - var ret string - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if s, ok := s.context[key].(string); ok { - ret = s - } - return ret -} - -// SetInt associates an integer context value to a key. -func (s *inStream) SetInt(key string, value int) { - s.setContextValue(key, value) -} - -// GetInt returns the context value associated with the key as an integer. -func (s *inStream) GetInt(key string) int { - var ret int - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if i, ok := s.context[key].(int); ok { - ret = i - } - return ret -} - -// SetFloat associates a float context value to a key. -func (s *inStream) SetFloat(key string, value float64) { - s.setContextValue(key, value) -} - -// GetFloat returns the context value associated with the key as a float64. -func (s *inStream) GetFloat(key string) float64 { - var ret float64 - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if f, ok := s.context[key].(float64); ok { - ret = f - } - return ret +func (s *inStream) Context() context.Context { + s.mu.RLock() + defer s.mu.RUnlock() + return s.ctx } -// SetBool associates a boolean context value to a key. -func (s *inStream) SetBool(key string, value bool) { - s.setContextValue(key, value) +func (s *inStream) Value(key interface{}) interface{} { + s.mu.RLock() + defer s.mu.RUnlock() + return s.ctx.Value(key) } -// GetBool returns the context value associated with the key as a boolean. -func (s *inStream) GetBool(key string) bool { - var ret bool - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if b, ok := s.context[key].(bool); ok { - ret = b - } - return ret +func (s *inStream) SetValue(key, value interface{}) { + s.mu.Lock() + defer s.mu.Unlock() + s.ctx = context.WithValue(s.ctx, key, value) } // Username returns current stream username. @@ -217,73 +167,74 @@ func (s *inStream) Presence() *xmpp.Presence { } // SendElement writes an XMPP element to the stream. -func (s *inStream) SendElement(elem xmpp.XElement) { +func (s *inStream) SendElement(ctx context.Context, elem xmpp.XElement) { if s.getState() == disconnected { return } - s.runQueue.Run(func() { s.writeElement(elem) }) + s.runQueue.Run(func() { s.writeElement(ctx, elem) }) } // Disconnect disconnects remote peer by closing the underlying TCP socket connection. -func (s *inStream) Disconnect(err error) { +func (s *inStream) Disconnect(ctx context.Context, err error) { if s.getState() == disconnected { return } waitCh := make(chan struct{}) s.runQueue.Run(func() { - s.disconnect(err) + s.disconnect(ctx, err) close(waitCh) }) <-waitCh } func (s *inStream) initializeAuthenticators() { - tr := s.cfg.transport + tr := s.tr + hasChannelBinding := len(tr.ChannelBindingBytes(transport.TLSUnique)) > 0 var authenticators []auth.Authenticator for _, a := range s.cfg.sasl { switch a { case "plain": - authenticators = append(authenticators, auth.NewPlain(s)) - - case "digest_md5": - authenticators = append(authenticators, auth.NewDigestMD5(s)) + authenticators = append(authenticators, auth.NewPlain(s, s.userRep)) case "scram_sha_1": - authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA1, false)) - authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA1, true)) + authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA1, false, s.userRep)) + if hasChannelBinding { + authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA1, true, s.userRep)) + } case "scram_sha_256": - authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA256, false)) - authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA256, true)) - - case "scram_sha_512": - authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA512, false)) - authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA512, true)) + authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA256, false, s.userRep)) + if hasChannelBinding { + authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA256, true, s.userRep)) + } } } s.authenticators = authenticators } func (s *inStream) connectTimeout() { - s.runQueue.Run(func() { s.disconnect(streamerror.ErrConnectionTimeout) }) + s.runQueue.Run(func() { + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + s.disconnect(ctx, streamerror.ErrConnectionTimeout) + }) } -func (s *inStream) handleElement(elem xmpp.XElement) { +func (s *inStream) handleElement(ctx context.Context, elem xmpp.XElement) { switch s.getState() { case connecting: - s.handleConnecting(elem) + s.handleConnecting(ctx, elem) case connected: - s.handleConnected(elem) + s.handleConnected(ctx, elem) case authenticated: - s.handleAuthenticated(elem) + s.handleAuthenticated(ctx, elem) case authenticating: - s.handleAuthenticating(elem) + s.handleAuthenticating(ctx, elem) case bound: - s.handleBound(elem) + s.handleBound(ctx, elem) } } -func (s *inStream) handleConnecting(elem xmpp.XElement) { +func (s *inStream) handleConnecting(ctx context.Context, elem xmpp.XElement) { // cancel connection timeout timer if s.connectTm != nil { s.connectTm.Stop() @@ -309,13 +260,13 @@ func (s *inStream) handleConnecting(elem xmpp.XElement) { features.AppendElements(s.authenticatedFeatures()) s.setState(authenticated) } - _ = s.sess.Open(features) + _ = s.sess.Open(ctx, features) } func (s *inStream) unauthenticatedFeatures() []xmpp.XElement { var features []xmpp.XElement - isSocketTr := s.cfg.transport.Type() == transport.Socket + isSocketTr := s.tr.Type() == transport.Socket if isSocketTr && !s.IsSecured() { startTLS := xmpp.NewElementName("starttls") @@ -351,7 +302,7 @@ func (s *inStream) unauthenticatedFeatures() []xmpp.XElement { func (s *inStream) authenticatedFeatures() []xmpp.XElement { var features []xmpp.XElement - isSocketTr := s.cfg.transport.Type() == transport.Socket + isSocketTr := s.tr.Type() == transport.Socket // attach compression feature compressionAvailable := isSocketTr && s.cfg.compression.Level != compress.NoCompression @@ -378,80 +329,80 @@ func (s *inStream) authenticatedFeatures() []xmpp.XElement { return features } -func (s *inStream) handleConnected(elem xmpp.XElement) { +func (s *inStream) handleConnected(ctx context.Context, elem xmpp.XElement) { switch elem.Name() { case "starttls": - s.proceedStartTLS(elem) + s.proceedStartTLS(ctx, elem) case "auth": - s.startAuthentication(elem) + s.startAuthentication(ctx, elem) case "iq": iq := elem.(*xmpp.IQ) if reg := s.mods.Register; reg != nil && reg.MatchesIQ(iq) { if s.IsSecured() { - reg.ProcessIQWithStream(iq, s) + reg.ProcessIQWithStream(ctx, iq, s) } else { // channel isn't safe enough to enable a password change - s.writeElement(iq.NotAuthorizedError()) + s.writeElement(ctx, iq.NotAuthorizedError()) } return } else if iq.Elements().ChildNamespace("query", "jabber:iq:auth") != nil { // don't allow non-SASL authentication - s.writeElement(iq.ServiceUnavailableError()) + s.writeElement(ctx, iq.ServiceUnavailableError()) return } fallthrough case "message", "presence": - s.disconnectWithStreamError(streamerror.ErrNotAuthorized) + s.disconnectWithStreamError(ctx, streamerror.ErrNotAuthorized) default: - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) } } -func (s *inStream) handleAuthenticating(elem xmpp.XElement) { +func (s *inStream) handleAuthenticating(ctx context.Context, elem xmpp.XElement) { if elem.Namespace() != saslNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } ath := s.activeAuth - _ = s.continueAuthentication(elem, ath) + _ = s.continueAuthentication(ctx, elem, ath) if ath.Authenticated() { - s.finishAuthentication(ath.Username()) + s.finishAuthentication(ctx, ath.Username()) } } -func (s *inStream) handleAuthenticated(elem xmpp.XElement) { +func (s *inStream) handleAuthenticated(ctx context.Context, elem xmpp.XElement) { switch elem.Name() { case "compress": if elem.Namespace() != compressProtocolNamespace { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) return } - s.compress(elem) + s.compress(ctx, elem) case "iq": iq := elem.(*xmpp.IQ) if len(s.JID().Resource()) == 0 { // Expecting bind - s.bindResource(iq) + s.bindResource(ctx, iq) } default: - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) } } -func (s *inStream) handleBound(elem xmpp.XElement) { +func (s *inStream) handleBound(ctx context.Context, elem xmpp.XElement) { // reset ping timer deadline if p := s.mods.Ping; p != nil { p.SchedulePing(s) } stanza, ok := elem.(xmpp.Stanza) if !ok { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) return } // handle session IQ @@ -459,9 +410,9 @@ func (s *inStream) handleBound(elem xmpp.XElement) { if iq.Elements().ChildNamespace("session", sessionNamespace) != nil { if !s.isSessionStarted() { s.setSessionStarted(true) - s.writeElement(iq.ResultIQ()) + s.writeElement(ctx, iq.ResultIQ()) } else { - s.writeElement(iq.NotAllowedError()) + s.writeElement(ctx, iq.NotAllowedError()) } return } @@ -470,56 +421,56 @@ func (s *inStream) handleBound(elem xmpp.XElement) { switch stanza := stanza.(type) { case *xmpp.IQ: if di := s.mods.DiscoInfo; di != nil && di.MatchesIQ(stanza) { - di.ProcessIQ(stanza) + di.ProcessIQ(ctx, stanza) return } break } - comp.ProcessStanza(stanza, s) + comp.ProcessStanza(ctx, stanza, s) return } - s.processStanza(stanza) + s.processStanza(ctx, stanza) } -func (s *inStream) proceedStartTLS(elem xmpp.XElement) { +func (s *inStream) proceedStartTLS(ctx context.Context, elem xmpp.XElement) { if s.IsSecured() { - s.disconnectWithStreamError(streamerror.ErrNotAuthorized) + s.disconnectWithStreamError(ctx, streamerror.ErrNotAuthorized) return } if len(elem.Namespace()) > 0 && elem.Namespace() != tlsNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } s.setSecured(true) - s.writeElement(xmpp.NewElementNamespace("proceed", tlsNamespace)) + s.writeElement(ctx, xmpp.NewElementNamespace("proceed", tlsNamespace)) - s.cfg.transport.StartTLS(&tls.Config{Certificates: s.router.Certificates()}, false) + s.tr.StartTLS(&tls.Config{Certificates: s.router.Hosts().Certificates()}, false) log.Infof("secured stream... id: %s", s.id) s.restartSession() } -func (s *inStream) compress(elem xmpp.XElement) { +func (s *inStream) compress(ctx context.Context, elem xmpp.XElement) { if s.isCompressed() { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) return } method := elem.Elements().Child("method") if method == nil || len(method.Text()) == 0 { failure := xmpp.NewElementNamespace("failure", compressProtocolNamespace) failure.AppendElement(xmpp.NewElementName("setup-failed")) - s.writeElement(failure) + s.writeElement(ctx, failure) return } if method.Text() != "zlib" { failure := xmpp.NewElementNamespace("failure", compressProtocolNamespace) failure.AppendElement(xmpp.NewElementName("unsupported-method")) - s.writeElement(failure) + s.writeElement(ctx, failure) return } - s.writeElement(xmpp.NewElementNamespace("compressed", compressProtocolNamespace)) + s.writeElement(ctx, xmpp.NewElementNamespace("compressed", compressProtocolNamespace)) - s.cfg.transport.EnableCompression(s.cfg.compression.Level) + s.tr.EnableCompression(s.cfg.compression.Level) s.setCompressed(true) log.Infof("compressed stream... id: %s", s.id) @@ -527,19 +478,19 @@ func (s *inStream) compress(elem xmpp.XElement) { s.restartSession() } -func (s *inStream) startAuthentication(elem xmpp.XElement) { +func (s *inStream) startAuthentication(ctx context.Context, elem xmpp.XElement) { if elem.Namespace() != saslNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } mechanism := elem.Attributes().Get("mechanism") for _, authenticator := range s.authenticators { if authenticator.Mechanism() == mechanism { - if err := s.continueAuthentication(elem, authenticator); err != nil { + if err := s.continueAuthentication(ctx, elem, authenticator); err != nil { return } if authenticator.Authenticated() { - s.finishAuthentication(authenticator.Username()) + s.finishAuthentication(ctx, authenticator.Username()) } else { s.activeAuth = authenticator s.setState(authenticating) @@ -550,21 +501,21 @@ func (s *inStream) startAuthentication(elem xmpp.XElement) { // ...mechanism not found... failure := xmpp.NewElementNamespace("failure", saslNamespace) failure.AppendElement(xmpp.NewElementName("invalid-mechanism")) - s.writeElement(failure) + s.writeElement(ctx, failure) } -func (s *inStream) continueAuthentication(elem xmpp.XElement, authr auth.Authenticator) error { - err := authr.ProcessElement(elem) +func (s *inStream) continueAuthentication(ctx context.Context, elem xmpp.XElement, authr auth.Authenticator) error { + err := authr.ProcessElement(ctx, elem) if saslErr, ok := err.(*auth.SASLError); ok { - s.failAuthentication(saslErr.Element()) + s.failAuthentication(ctx, saslErr.Element()) } else if err != nil { log.Error(err) - s.failAuthentication(auth.ErrSASLTemporaryAuthFailure.(*auth.SASLError).Element()) + s.failAuthentication(ctx, auth.ErrSASLTemporaryAuthFailure.(*auth.SASLError).Element()) } return err } -func (s *inStream) finishAuthentication(username string) { +func (s *inStream) finishAuthentication(_ context.Context, username string) { if s.activeAuth != nil { s.activeAuth.Reset() s.activeAuth = nil @@ -576,10 +527,10 @@ func (s *inStream) finishAuthentication(username string) { s.restartSession() } -func (s *inStream) failAuthentication(elem xmpp.XElement) { +func (s *inStream) failAuthentication(ctx context.Context, elem xmpp.XElement) { failure := xmpp.NewElementNamespace("failure", saslNamespace) failure.AppendElement(elem) - s.writeElement(failure) + s.writeElement(ctx, failure) if s.activeAuth != nil { s.activeAuth.Reset() @@ -588,21 +539,21 @@ func (s *inStream) failAuthentication(elem xmpp.XElement) { s.setState(connected) } -func (s *inStream) bindResource(iq *xmpp.IQ) { +func (s *inStream) bindResource(ctx context.Context, iq *xmpp.IQ) { bind := iq.Elements().ChildNamespace("bind", bindNamespace) if bind == nil { - s.writeElement(iq.NotAllowedError()) + s.writeElement(ctx, iq.NotAllowedError()) return } var resource string if resourceElem := bind.Elements().Child("resource"); resourceElem != nil { resource = resourceElem.Text() } else { - resource = uuid.New() + resource = uuid.New().String() } // try binding... var stm stream.C2S - streams := s.router.UserStreams(s.JID().Node()) + streams := s.router.LocalStreams(s.JID().Node()) for _, s := range streams { if s.Resource() == resource { stm = s @@ -612,19 +563,19 @@ func (s *inStream) bindResource(iq *xmpp.IQ) { switch s.cfg.resourceConflict { case Override: // override the resource with a server-generated resourcepart... - resource = uuid.New() + resource = uuid.New().String() case Replace: // terminate the session of the currently connected client... - stm.Disconnect(streamerror.ErrResourceConstraint) + stm.Disconnect(ctx, streamerror.ErrResourceConstraint) default: // disallow resource binding attempt... - s.writeElement(iq.ConflictError()) + s.writeElement(ctx, iq.ConflictError()) return } } userJID, err := jid.New(s.Username(), s.Domain(), resource, false) if err != nil { - s.writeElement(iq.BadRequestError()) + s.writeElement(ctx, iq.BadRequestError()) return } s.setJID(userJID) @@ -634,7 +585,7 @@ func (s *inStream) bindResource(iq *xmpp.IQ) { s.presence = xmpp.NewPresence(userJID, userJID, xmpp.UnavailableType) s.mu.Unlock() - s.router.Bind(s) + s.router.Bind(ctx, s) //...notify successful binding result := xmpp.NewIQType(iq.ID(), xmpp.ResultType) @@ -647,7 +598,7 @@ func (s *inStream) bindResource(iq *xmpp.IQ) { result.AppendElement(boundElem) s.setState(bound) - s.writeElement(result) + s.writeElement(ctx, result) // start pinging... if p := s.mods.Ping; p != nil { @@ -655,73 +606,85 @@ func (s *inStream) bindResource(iq *xmpp.IQ) { } } -func (s *inStream) processStanza(elem xmpp.Stanza) { +func (s *inStream) processStanza(ctx context.Context, elem xmpp.Stanza) { toJID := elem.ToJID() - if s.isBlockedJID(toJID) { // blocked JID? + if s.isBlockedJID(ctx, toJID) { // blocked JID? blocked := xmpp.NewElementNamespace("blocked", blockedErrorNamespace) resp := xmpp.NewErrorStanzaFromStanza(elem, xmpp.ErrNotAcceptable, []xmpp.XElement{blocked}) - s.writeElement(resp) + s.writeElement(ctx, resp) return } switch stanza := elem.(type) { case *xmpp.Presence: - s.processPresence(stanza) + s.processPresence(ctx, stanza) case *xmpp.IQ: - s.processIQ(stanza) + s.processIQ(ctx, stanza) case *xmpp.Message: - s.processMessage(stanza) + s.processMessage(ctx, stanza) } } -func (s *inStream) processIQ(iq *xmpp.IQ) { +func (s *inStream) processIQ(ctx context.Context, iq *xmpp.IQ) { toJID := iq.ToJID() - - replyOnBehalf := !toJID.IsFullWithUser() && s.router.IsLocalHost(toJID.Domain()) + replyOnBehalf := !toJID.IsFullWithUser() && (s.router.Hosts().IsLocalHost(toJID.Domain())) || + s.router.Hosts().IsConferenceHost(toJID.Domain()) if !replyOnBehalf { - switch s.router.Route(iq) { + switch s.router.Route(ctx, iq) { case router.ErrResourceNotFound: - s.writeElement(iq.ServiceUnavailableError()) + s.writeElement(ctx, iq.ServiceUnavailableError()) case router.ErrFailedRemoteConnect: - s.writeElement(iq.RemoteServerNotFoundError()) + s.writeElement(ctx, iq.RemoteServerNotFoundError()) case router.ErrBlockedJID: // destination user is a blocked JID if iq.IsGet() || iq.IsSet() { - s.writeElement(iq.ServiceUnavailableError()) + s.writeElement(ctx, iq.ServiceUnavailableError()) } } return } - s.mods.ProcessIQ(iq) + s.mods.ProcessIQ(ctx, iq) } -func (s *inStream) processPresence(presence *xmpp.Presence) { +func (s *inStream) processPresence(ctx context.Context, presence *xmpp.Presence) { + // is the presence stanza directed to the conference service + if s.router.Hosts().IsConferenceHost(presence.ToJID().Domain()) { + s.mods.Muc.ProcessPresence(ctx, presence) + return + } + if presence.ToJID().IsFullWithUser() { - _ = s.router.Route(presence) + _ = s.router.Route(ctx, presence) return } - replyOnBehalf := s.JID().Matches(presence.ToJID(), jid.MatchesBare) + replyOnBehalf := s.JID().MatchesWithOptions(presence.ToJID(), jid.MatchesBare) // update presence if replyOnBehalf && (presence.IsAvailable() || presence.IsUnavailable()) { s.setPresence(presence) } - // deliver presence to roster module + // process presence if r := s.mods.Roster; r != nil { - r.ProcessPresence(presence) + r.ProcessPresence(ctx, presence) } + // deliver offline messages if replyOnBehalf && presence.IsAvailable() && presence.Priority() >= 0 { if off := s.mods.Offline; off != nil { - off.DeliverOfflineMessages(s) + off.DeliverOfflineMessages(ctx, s) } } } -func (s *inStream) processMessage(message *xmpp.Message) { +func (s *inStream) processMessage(ctx context.Context, message *xmpp.Message) { msg := message + if s.router.Hosts().IsConferenceHost(message.ToJID().Domain()) { + s.mods.Muc.ProcessMessage(ctx, message) + return + } + sendMessage: - err := s.router.Route(msg) + err := s.router.Route(ctx, msg) switch err { case nil: break @@ -731,14 +694,14 @@ sendMessage: goto sendMessage case router.ErrNotAuthenticated: if off := s.mods.Offline; off != nil { - off.ArchiveMessage(message) + off.ArchiveMessage(ctx, message) return } fallthrough case router.ErrNotExistingAccount, router.ErrBlockedJID: - s.writeElement(message.ServiceUnavailableError()) + s.writeElement(ctx, message.ServiceUnavailableError()) case router.ErrFailedRemoteConnect: - s.writeElement(message.RemoteServerNotFoundError()) + s.writeElement(ctx, message.RemoteServerNotFoundError()) default: log.Error(err) } @@ -746,83 +709,89 @@ sendMessage: // Runs on it's own goroutine func (s *inStream) doRead() { + s.scheduleReadTimeout() elem, sErr := s.sess.Receive() + s.cancelReadTimeout() + + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) if sErr == nil { - s.runQueue.Run(func() { s.readElement(elem) }) + s.runQueue.Run(func() { s.readElement(ctx, elem) }) } else { s.runQueue.Run(func() { if s.getState() == disconnected { return } - s.handleSessionError(sErr) + s.handleSessionError(ctx, sErr) }) } } -func (s *inStream) handleSessionError(sErr *session.Error) { +func (s *inStream) handleSessionError(ctx context.Context, sErr *session.Error) { switch err := sErr.UnderlyingErr.(type) { case nil: - s.disconnect(nil) + s.disconnect(ctx, nil) case *streamerror.Error: - s.disconnectWithStreamError(err) + s.disconnectWithStreamError(ctx, err) case *xmpp.StanzaError: - s.writeStanzaErrorResponse(sErr.Element, err) + s.writeStanzaErrorResponse(ctx, sErr.Element, err) default: log.Error(err) - s.disconnectWithStreamError(streamerror.ErrUndefinedCondition) + s.disconnectWithStreamError(ctx, streamerror.ErrUndefinedCondition) } } -func (s *inStream) writeStanzaErrorResponse(elem xmpp.XElement, stanzaErr *xmpp.StanzaError) { +func (s *inStream) writeStanzaErrorResponse(ctx context.Context, elem xmpp.XElement, stanzaErr *xmpp.StanzaError) { resp := xmpp.NewElementFromElement(elem) resp.SetType(xmpp.ErrorType) resp.SetFrom(resp.To()) resp.SetTo(s.JID().String()) resp.AppendElement(stanzaErr.Element()) - s.writeElement(resp) + s.writeElement(ctx, resp) } -func (s *inStream) writeElement(elem xmpp.XElement) { - s.sess.Send(elem) +func (s *inStream) writeElement(ctx context.Context, elem xmpp.XElement) { + if err := s.sess.Send(ctx, elem); err != nil { + log.Error(err) + } } -func (s *inStream) readElement(elem xmpp.XElement) { +func (s *inStream) readElement(ctx context.Context, elem xmpp.XElement) { if elem != nil { - s.handleElement(elem) + s.handleElement(ctx, elem) } if s.getState() != disconnected { - go s.doRead() // Keep reading... + go s.doRead() // keep reading... } } -func (s *inStream) disconnect(err error) { +func (s *inStream) disconnect(ctx context.Context, err error) { if s.getState() == disconnected { return } switch err { case nil: - s.disconnectClosingSession(false, true) + s.disconnectClosingSession(ctx, false, true) default: if stmErr, ok := err.(*streamerror.Error); ok { - s.disconnectWithStreamError(stmErr) + s.disconnectWithStreamError(ctx, stmErr) } else { log.Error(err) - s.disconnectClosingSession(false, true) + s.disconnectClosingSession(ctx, false, true) } } } -func (s *inStream) disconnectWithStreamError(err *streamerror.Error) { +func (s *inStream) disconnectWithStreamError(ctx context.Context, err *streamerror.Error) { if s.getState() == connecting { - _ = s.sess.Open(nil) + _ = s.sess.Open(ctx, nil) } - s.writeElement(err.Element()) + s.writeElement(ctx, err.Element()) unregister := err != streamerror.ErrSystemShutdown - s.disconnectClosingSession(true, unregister) + s.disconnectClosingSession(ctx, true, unregister) } -func (s *inStream) disconnectClosingSession(closeSession, unbind bool) { +func (s *inStream) disconnectClosingSession(ctx context.Context, closeSession, unbind bool) { // stop pinging... if p := s.mods.Ping; p != nil { p.CancelPing(s) @@ -830,76 +799,62 @@ func (s *inStream) disconnectClosingSession(closeSession, unbind bool) { // send 'unavailable' presence when disconnecting if presence := s.Presence(); presence != nil && presence.IsAvailable() { if r := s.mods.Roster; r != nil { - r.ProcessPresence(xmpp.NewPresence(s.JID(), s.JID().ToBareJID(), xmpp.UnavailableType)) + r.ProcessPresence(ctx, xmpp.NewPresence(s.JID(), s.JID().ToBareJID(), xmpp.UnavailableType)) } } if closeSession { - _ = s.sess.Close() + _ = s.sess.Close(ctx) } // unregister stream if unbind { - s.router.Unbind(s.JID()) + s.router.Unbind(ctx, s.JID()) } + s.ctxCancelFn() + // notify disconnection if s.cfg.onDisconnect != nil { s.cfg.onDisconnect(s) } s.setState(disconnected) - _ = s.cfg.transport.Close() + _ = s.tr.Close() s.runQueue.Stop(nil) // stop processing messages } -func (s *inStream) isBlockedJID(j *jid.JID) bool { - if j.IsServer() && s.router.IsLocalHost(j.Domain()) { +func (s *inStream) isBlockedJID(ctx context.Context, j *jid.JID) bool { + blockList, err := s.blockListRep.FetchBlockListItems(ctx, s.Username()) + if err != nil { + log.Error(err) + return false + } + if len(blockList) == 0 { return false } - return s.router.IsBlockedJID(j, s.Username()) + blockListJIDs := make([]jid.JID, len(blockList)) + for i, listItem := range blockList { + j, _ := jid.NewWithString(listItem.JID, true) + blockListJIDs[i] = *j + } + for _, blockedJID := range blockListJIDs { + if blockedJID.Matches(j) { + return true + } + } + return false } func (s *inStream) restartSession() { s.sess = session.New(s.id, &session.Config{ JID: s.JID(), - Transport: s.cfg.transport, MaxStanzaSize: s.cfg.maxStanzaSize, - }, s.router) + }, s.tr, s.router.Hosts()) s.setState(connecting) } -func (s *inStream) setContextValue(key string, value interface{}) { - s.contextMu.Lock() - defer s.contextMu.Unlock() - s.context[key] = value - - // notify the whole roster about the context update. - if c := s.router.Cluster(); c != nil { - c.BroadcastMessage(&cluster.Message{ - Type: cluster.MsgUpdateContext, - Node: c.LocalNode(), - Payloads: []cluster.MessagePayload{{ - JID: s.JID(), - Context: map[string]interface{}{key: value}, - }}, - }) - } -} - func (s *inStream) setPresence(presence *xmpp.Presence) { s.mu.Lock() - defer s.mu.Unlock() s.presence = presence - - // notify the whole roster about the presence update. - if c := s.router.Cluster(); c != nil { - c.BroadcastMessage(&cluster.Message{ - Type: cluster.MsgUpdatePresence, - Node: c.LocalNode(), - Payloads: []cluster.MessagePayload{{ - JID: s.jid, - Stanza: presence, - }}, - }) - } + s.mu.Unlock() } func (s *inStream) setJID(j *jid.JID) { @@ -944,6 +899,25 @@ func (s *inStream) setSessionStarted(sessStarted bool) { s.sessStarted = sessStarted } +func (s *inStream) scheduleReadTimeout() { + s.mu.Lock() + s.readTimeoutTm = time.AfterFunc(s.cfg.keepAlive, s.readTimeout) + s.mu.Unlock() +} + +func (s *inStream) cancelReadTimeout() { + s.mu.Lock() + s.readTimeoutTm.Stop() + s.mu.Unlock() +} + +func (s *inStream) readTimeout() { + s.runQueue.Run(func() { + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + s.disconnect(ctx, streamerror.ErrConnectionTimeout) + }) +} + func (s *inStream) setState(state uint32) { atomic.StoreUint32(&s.state, state) } diff --git a/c2s/in_test.go b/c2s/in_test.go index e32453b4b..18dcd3bec 100644 --- a/c2s/in_test.go +++ b/c2s/in_test.go @@ -6,49 +6,48 @@ package c2s import ( + "context" "testing" "time" + "github.com/google/uuid" "github.com/ortuman/jackal/component" "github.com/ortuman/jackal/model" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/transport" "github.com/ortuman/jackal/transport/compress" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" - "github.com/pborman/uuid" "github.com/stretchr/testify/require" ) func TestStream_ConnectTimeout(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - stm, _ := tUtilStreamInit(r) + stm, _ := tUtilStreamInit(r, userRep, blockListRep) time.Sleep(time.Millisecond * 1500) require.Equal(t, disconnected, stm.getState()) } func TestStream_Disconnect(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - stm, conn := tUtilStreamInit(r) - stm.Disconnect(nil) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) + stm.Disconnect(context.Background(), nil) require.True(t, conn.waitClose()) require.Equal(t, disconnected, stm.getState()) } func TestStream_Features(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") // unsecured features - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) elem := conn.outboundRead() @@ -61,7 +60,7 @@ func TestStream_Features(t *testing.T) { require.Equal(t, connected, stm.getState()) // secured features - stm2, conn2 := tUtilStreamInit(r) + stm2, conn2 := tUtilStreamInit(r, userRep, blockListRep) stm2.setSecured(true) tUtilStreamOpen(conn2) @@ -75,18 +74,17 @@ func TestStream_Features(t *testing.T) { } func TestStream_TLS(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... - conn.inboundWrite([]byte(``)) + _, _ = conn.inboundWrite([]byte(``)) elem := conn.outboundRead() @@ -97,34 +95,28 @@ func TestStream_TLS(t *testing.T) { } func TestStream_FailAuthenticate(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - _, conn := tUtilStreamInit(r) + _, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... // wrong mechanism - conn.inboundWrite([]byte(``)) + _, _ = conn.inboundWrite([]byte(``)) elem := conn.outboundRead() require.Equal(t, "failure", elem.Name()) - conn.inboundWrite([]byte(``)) - - elem = conn.outboundRead() - require.Equal(t, "challenge", elem.Name()) - - conn.inboundWrite([]byte(`dXNlcm5hbWU9Im9ydHVtYW4iLHJlYWxtPSJsb2NhbGhvc3QiLG5vbmNlPSJuY3prcXJFb3Uyait4ek1pcUgxV1lBdHh6dlNCSzFVbHNOejNLQUJsSjd3PSIsY25vbmNlPSJlcHNMSzhFQU8xVWVFTUpLVjdZNXgyYUtqaHN2UXpSMGtIdFM0ZGljdUFzPSIsbmM9MDAwMDAwMDEsZGlnZXN0LXVyaT0ieG1wcC9sb2NhbGhvc3QiLHFvcD1hdXRoLHJlc3BvbnNlPTVmODRmNTk2YWE4ODc0OWY2ZjZkZTYyZjliNjhkN2I2LGNoYXJzZXQ9dXRmLTg=`)) + _, _ = conn.inboundWrite([]byte(`AHVzZXIAYQ==`)) elem = conn.outboundRead() require.Equal(t, "failure", elem.Name()) // non-SASL - conn.inboundWrite([]byte(` + _, _ = conn.inboundWrite([]byte(` bill Calli0pe @@ -137,12 +129,11 @@ func TestStream_FailAuthenticate(t *testing.T) { } func TestStream_Compression(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -154,13 +145,13 @@ func TestStream_Compression(t *testing.T) { _ = conn.outboundRead() // read stream features... // no method... - conn.inboundWrite([]byte(``)) + _, _ = conn.inboundWrite([]byte(``)) elem := conn.outboundRead() require.Equal(t, "failure", elem.Name()) require.NotNil(t, elem.Elements().Child("setup-failed")) // invalid method... - conn.inboundWrite([]byte(` + _, _ = conn.inboundWrite([]byte(` 7z `)) elem = conn.outboundRead() @@ -168,7 +159,7 @@ func TestStream_Compression(t *testing.T) { require.NotNil(t, elem.Elements().Child("unsupported-method")) // valid method... - conn.inboundWrite([]byte(` + _, _ = conn.inboundWrite([]byte(` zlib `)) @@ -176,16 +167,17 @@ func TestStream_Compression(t *testing.T) { require.Equal(t, "compressed", elem.Name()) require.Equal(t, "http://jabber.org/protocol/compress", elem.Namespace()) + time.Sleep(time.Millisecond * 100) // wait until processed... + require.True(t, stm.isCompressed()) } func TestStream_StartSession(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -203,12 +195,11 @@ func TestStream_StartSession(t *testing.T) { } func TestStream_SendIQ(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -225,27 +216,27 @@ func TestStream_SendIQ(t *testing.T) { require.Equal(t, bound, stm.getState()) // request roster... - iqID := uuid.New() + iqID := uuid.New().String() iq := xmpp.NewIQType(iqID, xmpp.GetType) iq.AppendElement(xmpp.NewElementNamespace("query", "jabber:iq:roster")) - conn.inboundWrite([]byte(iq.String())) + _, _ = conn.inboundWrite([]byte(iq.String())) elem := conn.outboundRead() require.Equal(t, "iq", elem.Name()) require.Equal(t, iqID, elem.ID()) require.NotNil(t, elem.Elements().ChildNamespace("query", "jabber:iq:roster")) - require.True(t, stm.GetBool("roster:requested")) + requested, _ := stm.Value("roster:requested").(bool) + require.True(t, requested) } func TestStream_SendPresence(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -261,7 +252,7 @@ func TestStream_SendPresence(t *testing.T) { require.Equal(t, bound, stm.getState()) - conn.inboundWrite([]byte(` + _, _ = conn.inboundWrite([]byte(` away away! @@ -285,12 +276,11 @@ func TestStream_SendPresence(t *testing.T) { } func TestStream_SendMessage(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -310,9 +300,11 @@ func TestStream_SendMessage(t *testing.T) { jTo, _ := jid.New("ortuman", "localhost", "garden", true) stm2 := stream.NewMockC2S("abcd7890", jTo) - r.Bind(stm2) + stm2.SetPresence(xmpp.NewPresence(jTo, jTo, xmpp.AvailableType)) + + r.Bind(context.Background(), stm2) - msgID := uuid.New() + msgID := uuid.New().String() msg := xmpp.NewMessageType(msgID, xmpp.ChatType) msg.SetFromJID(jFrom) msg.SetToJID(jTo) @@ -320,7 +312,7 @@ func TestStream_SendMessage(t *testing.T) { body.SetText("Hi buddy!") msg.AppendElement(body) - conn.inboundWrite([]byte(msg.String())) + _, _ = conn.inboundWrite([]byte(msg.String())) // to full jid... elem := stm2.ReceiveElement() @@ -329,19 +321,18 @@ func TestStream_SendMessage(t *testing.T) { // to bare jid... msg.SetToJID(jTo.ToBareJID()) - conn.inboundWrite([]byte(msg.String())) + _, _ = conn.inboundWrite([]byte(msg.String())) elem = stm2.ReceiveElement() require.Equal(t, "message", elem.Name()) require.Equal(t, msgID, elem.ID()) } func TestStream_SendToBlockedJID(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, userRep, blockListRep := setupTest("localhost") - storage.InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit(r) + stm, conn := tUtilStreamInit(r, userRep, blockListRep) tUtilStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -357,13 +348,13 @@ func TestStream_SendToBlockedJID(t *testing.T) { require.Equal(t, bound, stm.getState()) - storage.InsertBlockListItems([]model.BlockListItem{{ + _ = blockListRep.InsertBlockListItem(context.Background(), &model.BlockListItem{ Username: "user", JID: "hamlet@localhost", - }}) + }) // send presence to a blocked JID... - conn.inboundWrite([]byte(``)) + _, _ = conn.inboundWrite([]byte(``)) elem := conn.outboundRead() require.Equal(t, "presence", elem.Name()) @@ -376,29 +367,19 @@ func tUtilStreamOpen(conn *fakeSocketConn) { ` - conn.inboundWrite([]byte(s)) + _, _ = conn.inboundWrite([]byte(s)) } func tUtilStreamAuthenticate(conn *fakeSocketConn, t *testing.T) { - conn.inboundWrite([]byte(``)) + _, _ = conn.inboundWrite([]byte(`AHVzZXIAcGVuY2ls`)) elem := conn.outboundRead() - require.Equal(t, "challenge", elem.Name()) - - conn.inboundWrite([]byte(`dXNlcm5hbWU9InVzZXIiLHJlYWxtPSJsb2NhbGhvc3QiLG5vbmNlPSJuY3prcXJFb3Uyait4ek1pcUgxV1lBdHh6dlNCSzFVbHNOejNLQUJsSjd3PSIsY25vbmNlPSJlcHNMSzhFQU8xVWVFTUpLVjdZNXgyYUtqaHN2UXpSMGtIdFM0ZGljdUFzPSIsbmM9MDAwMDAwMDEsZGlnZXN0LXVyaT0ieG1wcC9sb2NhbGhvc3QiLHFvcD1hdXRoLHJlc3BvbnNlPTVmODRmNTk2YWE4ODc0OWY2ZjZkZTYyZjliNjhkN2I2LGNoYXJzZXQ9dXRmLTg=`)) - - elem = conn.outboundRead() - require.Equal(t, "challenge", elem.Name()) - - conn.inboundWrite([]byte(``)) - - elem = conn.outboundRead() require.Equal(t, "success", elem.Name()) } func tUtilStreamBind(conn *fakeSocketConn, t *testing.T) { // bind a resource - conn.inboundWrite([]byte(` + _, _ = conn.inboundWrite([]byte(` balcony @@ -411,7 +392,7 @@ func tUtilStreamBind(conn *fakeSocketConn, t *testing.T) { func tUtilStreamStartSession(conn *fakeSocketConn, t *testing.T) { // open session - conn.inboundWrite([]byte(` + _, _ = conn.inboundWrite([]byte(` `)) @@ -422,17 +403,25 @@ func tUtilStreamStartSession(conn *fakeSocketConn, t *testing.T) { time.Sleep(time.Millisecond * 100) // wait until stream internal state changes } -func tUtilStreamInit(r *router.Router) (*inStream, *fakeSocketConn) { +func tUtilStreamInit(r router.Router, userRep repository.User, blockListRep repository.BlockList) (*inStream, *fakeSocketConn) { conn := newFakeSocketConn() - tr := transport.NewSocketTransport(conn, 4096) - stm := newStream("abcd1234", tUtilInStreamDefaultConfig(tr), tUtilInitModules(r), &component.Components{}, r) + tr := transport.NewSocketTransport(conn) + stm := newStream( + "abcd1234", + tUtilInStreamDefaultConfig(), + tr, + tUtilInitModules(r), + &component.Components{}, + r, + userRep, + blockListRep) return stm.(*inStream), conn } -func tUtilInStreamDefaultConfig(tr transport.Transport) *streamConfig { +func tUtilInStreamDefaultConfig() *streamConfig { return &streamConfig{ connectTimeout: time.Second, - transport: tr, + keepAlive: time.Second, maxStanzaSize: 8192, resourceConflict: Reject, compression: CompressConfig{Level: compress.DefaultCompression}, @@ -440,10 +429,11 @@ func tUtilInStreamDefaultConfig(tr transport.Transport) *streamConfig { } } -func tUtilInitModules(r *router.Router) *module.Modules { +func tUtilInitModules(r router.Router) *module.Modules { modules := map[string]struct{}{} modules["roster"] = struct{}{} modules["blocking_command"] = struct{}{} - return module.New(&module.Config{Enabled: modules}, r) + repContainer, _ := storage.New(&storage.Config{Type: storage.Memory}) + return module.New(&module.Config{Enabled: modules}, r, repContainer, "alloc-1234") } diff --git a/c2s/router/resources.go b/c2s/router/resources.go new file mode 100644 index 000000000..185a968c7 --- /dev/null +++ b/c2s/router/resources.go @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package c2srouter + +import ( + "context" + "sync" + + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" +) + +type resources struct { + mu sync.RWMutex + streams []stream.C2S +} + +func (r *resources) len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.streams) +} + +func (r *resources) allStreams() []stream.C2S { + r.mu.RLock() + defer r.mu.RUnlock() + return r.streams +} + +func (r *resources) stream(resource string) stream.C2S { + r.mu.RLock() + defer r.mu.RUnlock() + + for _, stm := range r.streams { + if stm.Resource() == resource { + return stm + } + } + return nil +} + +func (r *resources) bind(stm stream.C2S) { + r.mu.Lock() + defer r.mu.Unlock() + + res := stm.Resource() + for _, s := range r.streams { + if s.Resource() == res { + return + } + } + r.streams = append(r.streams, stm) +} + +func (r *resources) unbind(res string) { + r.mu.Lock() + defer r.mu.Unlock() + + for i, s := range r.streams { + if s.Resource() != res { + continue + } + r.streams = append(r.streams[:i], r.streams[i+1:]...) + return + } +} + +func (r *resources) route(ctx context.Context, stanza xmpp.Stanza) error { + toJID := stanza.ToJID() + if toJID.IsFullWithUser() { + for _, stm := range r.streams { + if p := stm.Presence(); p != nil && p.IsAvailable() && stm.Resource() == toJID.Resource() { + stm.SendElement(ctx, stanza) + return nil + } + } + return router.ErrResourceNotFound + } + switch stanza.(type) { + case *xmpp.Message: + // send to highest priority stream + var highestPriority int8 + var recipient stream.C2S + + for _, stm := range r.streams { + if p := stm.Presence(); p != nil && p.IsAvailable() && p.Priority() > highestPriority { + recipient = stm + highestPriority = p.Priority() + } + } + if recipient == nil { + goto broadcast + } + recipient.SendElement(ctx, stanza) + return nil + } + +broadcast: + // broadcast toJID all streams + for _, stm := range r.streams { + if p := stm.Presence(); p != nil && p.IsAvailable() { + stm.SendElement(ctx, stanza) + } + } + return nil +} diff --git a/c2s/router/resources_test.go b/c2s/router/resources_test.go new file mode 100644 index 000000000..34a594e9a --- /dev/null +++ b/c2s/router/resources_test.go @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package c2srouter + +import ( + "context" + "testing" + + "github.com/google/uuid" + + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestResources_Binding(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + stm := stream.NewMockC2S("id-1", j) + + res := resources{} + require.Equal(t, 0, res.len()) + + res.bind(stm) + require.Equal(t, 1, res.len()) + + require.NotNil(t, res.stream("yard")) + require.Len(t, res.allStreams(), 1) + + res.unbind("yard") + + require.Nil(t, res.stream("yard")) + require.Len(t, res.allStreams(), 0) +} + +func TestResources_Route(t *testing.T) { + j1, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + j2, _ := jid.NewWithString("ortuman@jackal.im/balcony", true) + j3, _ := jid.NewWithString("ortuman@jackal.im/chamber", true) + j4, _ := jid.NewWithString("ortuman@jackal.im", true) + + stm1 := stream.NewMockC2S("id-1", j1) + stm2 := stream.NewMockC2S("id-2", j2) + + stm1.SetPresence(xmpp.NewPresence(j1.ToBareJID(), j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2.ToBareJID(), j2, xmpp.AvailableType)) + + res := resources{} + res.bind(stm1) + res.bind(stm2) + + msgID := uuid.New().String() + msg := xmpp.NewMessageType(msgID, xmpp.NormalType) + msg.SetFromJID(j1) + msg.SetToJID(j3) + + err := res.route(context.Background(), msg) + require.Equal(t, router.ErrResourceNotFound, err) + + msg.SetToJID(j2) + err = res.route(context.Background(), msg) + require.Nil(t, err) + + elem := stm2.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "message", elem.Name()) + require.Equal(t, msgID, elem.ID()) + + msgID = uuid.New().String() + msg = xmpp.NewMessageType(msgID, xmpp.NormalType) + msg.SetFromJID(j1) + msg.SetToJID(j4) + + err = res.route(context.Background(), msg) + require.Nil(t, err) + + elem1 := stm1.ReceiveElement() + elem2 := stm2.ReceiveElement() + + require.Equal(t, "message", elem1.Name()) + require.Equal(t, elem1.ID(), elem2.ID()) + require.Equal(t, elem1.Name(), elem2.Name()) +} diff --git a/c2s/router/router.go b/c2s/router/router.go new file mode 100644 index 000000000..5d145a389 --- /dev/null +++ b/c2s/router/router.go @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package c2srouter + +import ( + "context" + "sync" + + "github.com/ortuman/jackal/log" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +type c2sRouter struct { + mu sync.RWMutex + tbl map[string]*resources + userRep repository.User + blockListRep repository.BlockList +} + +func New(userRep repository.User, blockListRep repository.BlockList) router.C2SRouter { + return &c2sRouter{ + tbl: make(map[string]*resources), + userRep: userRep, + blockListRep: blockListRep, + } +} + +func (r *c2sRouter) Route(ctx context.Context, stanza xmpp.Stanza, validateStanza bool) error { + fromJID := stanza.FromJID() + toJID := stanza.ToJID() + + // validate if sender JID is blocked + if validateStanza && r.isBlockedJID(ctx, toJID, fromJID.Node()) { + return router.ErrBlockedJID + } + username := stanza.ToJID().Node() + r.mu.RLock() + rs := r.tbl[username] + r.mu.RUnlock() + + if rs == nil { + exists, err := r.userRep.UserExists(ctx, username) + if err != nil { + return err + } + if exists { + return router.ErrNotAuthenticated + } + return router.ErrNotExistingAccount + } + return rs.route(ctx, stanza) +} + +func (r *c2sRouter) Bind(stm stream.C2S) { + user := stm.Username() + r.mu.RLock() + rs := r.tbl[user] + r.mu.RUnlock() + + if rs == nil { + r.mu.Lock() + rs = r.tbl[user] // avoid double initialization + if rs == nil { + rs = &resources{} + r.tbl[user] = rs + } + r.mu.Unlock() + } + rs.bind(stm) + + log.Infof("bound c2s stream... (%s/%s)", stm.Username(), stm.Resource()) +} + +func (r *c2sRouter) Unbind(user, resource string) { + r.mu.RLock() + rs := r.tbl[user] + r.mu.RUnlock() + + if rs == nil { + return + } + r.mu.Lock() + rs.unbind(resource) + if rs.len() == 0 { + delete(r.tbl, user) + } + r.mu.Unlock() + + log.Infof("unbound c2s stream... (%s/%s)", user, resource) +} + +func (r *c2sRouter) Stream(username, resource string) stream.C2S { + r.mu.RLock() + rs := r.tbl[username] + r.mu.RUnlock() + + if rs == nil { + return nil + } + return rs.stream(resource) +} + +func (r *c2sRouter) Streams(username string) []stream.C2S { + r.mu.RLock() + rs := r.tbl[username] + r.mu.RUnlock() + + if rs == nil { + return nil + } + return rs.allStreams() +} + +func (r *c2sRouter) isBlockedJID(ctx context.Context, j *jid.JID, username string) bool { + blockList, err := r.blockListRep.FetchBlockListItems(ctx, username) + if err != nil { + log.Error(err) + return false + } + if len(blockList) == 0 { + return false + } + blockListJIDs := make([]jid.JID, len(blockList)) + for i, listItem := range blockList { + j, _ := jid.NewWithString(listItem.JID, true) + blockListJIDs[i] = *j + } + for _, blockedJID := range blockListJIDs { + if blockedJID.Matches(j) { + return true + } + } + return false +} diff --git a/c2s/router/router_test.go b/c2s/router/router_test.go new file mode 100644 index 000000000..31af4e3db --- /dev/null +++ b/c2s/router/router_test.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package c2srouter + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/model" + "github.com/ortuman/jackal/router" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestRouter_Binding(t *testing.T) { + j1, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + j2, _ := jid.NewWithString("ortuman@jackal.im/balcony", true) + + stm1 := stream.NewMockC2S("id-1", j1) + stm2 := stream.NewMockC2S("id-1", j2) + + r, _, _ := setupTest() + + r.Bind(stm1) + r.Bind(stm2) + stm1.SetPresence(xmpp.NewPresence(j1.ToBareJID(), j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2.ToBareJID(), j2, xmpp.AvailableType)) + + require.Len(t, r.Streams("ortuman"), 2) + + require.NotNil(t, r.Stream("ortuman", "yard")) + require.NotNil(t, r.Stream("ortuman", "balcony")) + + r.Unbind("ortuman", "yard") + r.Unbind("ortuman", "balcony") + + require.Len(t, r.Streams("ortuman"), 0) + + r.(*c2sRouter).mu.RLock() + require.Len(t, r.(*c2sRouter).tbl, 0) + r.(*c2sRouter).mu.RUnlock() +} + +func TestRouter_Routing(t *testing.T) { + j1, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + j2, _ := jid.NewWithString("romeo@jackal.im/deadlyresource", true) + stm1 := stream.NewMockC2S("id-1", j1) + stm2 := stream.NewMockC2S("id-2", j2) + + r, userRep, blockListRep := setupTest() + + err := r.Route(context.Background(), xmpp.NewPresence(j1, j1, xmpp.AvailableType), true) + require.Equal(t, router.ErrNotExistingAccount, err) + + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "ortuman"}) + _ = userRep.UpsertUser(context.Background(), &model.User{Username: "romeo"}) + + err = r.Route(context.Background(), xmpp.NewPresence(j1, j1, xmpp.AvailableType), true) + require.Equal(t, router.ErrNotAuthenticated, err) + + r.Bind(stm1) + stm1.SetPresence(xmpp.NewPresence(j1.ToBareJID(), j1, xmpp.AvailableType)) + + err = r.Route(context.Background(), xmpp.NewPresence(j1, j1, xmpp.AvailableType), true) + require.Nil(t, err) + + // block jid + r.Bind(stm2) + stm2.SetPresence(xmpp.NewPresence(j2.ToBareJID(), j2, xmpp.AvailableType)) + + _ = blockListRep.InsertBlockListItem(context.Background(), &model.BlockListItem{ + Username: "ortuman", + JID: "jackal.im/deadlyresource", + }) + + err = r.Route(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2, xmpp.AvailableType), true) + require.Equal(t, router.ErrBlockedJID, err) +} + +func setupTest() (router.C2SRouter, repository.User, repository.BlockList) { + userRep := memorystorage.NewUser() + blockListRep := memorystorage.NewBlockList() + return New(userRep, blockListRep), userRep, blockListRep +} diff --git a/c2s/server.go b/c2s/server.go index e9d78f742..783bdf46c 100644 --- a/c2s/server.go +++ b/c2s/server.go @@ -7,20 +7,19 @@ package c2s import ( "context" - "crypto/tls" "fmt" "net" - "net/http" "strconv" "sync" "sync/atomic" + "time" - "github.com/gorilla/websocket" "github.com/ortuman/jackal/component" streamerror "github.com/ortuman/jackal/errors" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/transport" ) @@ -28,16 +27,29 @@ import ( var listenerProvider = net.Listen type server struct { - cfg *Config - mods *module.Modules - comps *component.Components - router *router.Router - inConns sync.Map - ln net.Listener - wsSrv *http.Server - wsUpgrader *websocket.Upgrader - stmSeq uint64 - listening uint32 + cfg *Config + mods *module.Modules + comps *component.Components + router router.Router + userRep repository.User + blockListRep repository.BlockList + inConnectionsMu sync.Mutex + inConnections map[string]stream.C2S + ln net.Listener + stmSeq uint64 + listening uint32 +} + +func newC2SServer(config *Config, mods *module.Modules, comps *component.Components, router router.Router, userRep repository.User, blockListRep repository.BlockList) c2sServer { + return &server{ + cfg: config, + mods: mods, + comps: comps, + router: router, + userRep: userRep, + blockListRep: blockListRep, + inConnections: make(map[string]stream.C2S), + } } func (s *server) start() { @@ -51,9 +63,6 @@ func (s *server) start() { switch s.cfg.Transport.Type { case transport.Socket: err = s.listenSocketConn(address) - case transport.WebSocket: - err = s.listenWebSocketConn(address) - break } if err != nil { log.Fatalf("%v", err) @@ -71,40 +80,13 @@ func (s *server) listenSocketConn(address string) error { for atomic.LoadUint32(&s.listening) == 1 { conn, err := ln.Accept() if err == nil { - go s.startStream(transport.NewSocketTransport(conn, s.cfg.Transport.KeepAlive)) + go s.startStream(transport.NewSocketTransport(conn), s.cfg.KeepAlive) continue } } return nil } -func (s *server) listenWebSocketConn(address string) error { - http.HandleFunc(s.cfg.Transport.URLPath, s.websocketUpgrade) - - s.wsSrv = &http.Server{TLSConfig: &tls.Config{Certificates: s.router.Certificates()}} - s.wsUpgrader = &websocket.Upgrader{ - Subprotocols: []string{"xmpp"}, - CheckOrigin: func(r *http.Request) bool { return r.Header.Get("Sec-WebSocket-Protocol") == "xmpp" }, - } - - // start listening - ln, err := listenerProvider("tcp", address) - if err != nil { - return err - } - atomic.StoreUint32(&s.listening, 1) - return s.wsSrv.ServeTLS(ln, "", "") -} - -func (s *server) websocketUpgrade(w http.ResponseWriter, r *http.Request) { - conn, err := s.wsUpgrader.Upgrade(w, r, nil) - if err != nil { - log.Error(err) - return - } - s.startStream(transport.NewWebSocketTransport(conn, s.cfg.Transport.KeepAlive)) -} - func (s *server) shutdown(ctx context.Context) error { if atomic.CompareAndSwapUint32(&s.listening, 1, 0) { // stop listening @@ -113,13 +95,9 @@ func (s *server) shutdown(ctx context.Context) error { if err := s.ln.Close(); err != nil { return err } - case transport.WebSocket: - if err := s.wsSrv.Shutdown(ctx); err != nil { - return err - } } // close all connections - c, err := closeConnections(ctx, &s.inConns) + c, err := s.closeConnections(ctx) if err != nil { return err } @@ -128,27 +106,34 @@ func (s *server) shutdown(ctx context.Context) error { return nil } -func (s *server) startStream(tr transport.Transport) { +func (s *server) startStream(tr transport.Transport, keepAlive time.Duration) { cfg := &streamConfig{ - transport: tr, resourceConflict: s.cfg.ResourceConflict, connectTimeout: s.cfg.ConnectTimeout, + keepAlive: s.cfg.KeepAlive, + timeout: s.cfg.Timeout, maxStanzaSize: s.cfg.MaxStanzaSize, sasl: s.cfg.SASL, compression: s.cfg.Compression, onDisconnect: s.unregisterStream, } - stm := newStream(s.nextID(), cfg, s.mods, s.comps, s.router) + stm := newStream(s.nextID(), cfg, tr, s.mods, s.comps, s.router, s.userRep, s.blockListRep) s.registerStream(stm) } func (s *server) registerStream(stm stream.C2S) { - s.inConns.Store(stm.ID(), stm) + s.inConnectionsMu.Lock() + s.inConnections[stm.ID()] = stm + s.inConnectionsMu.Unlock() + log.Infof("registered c2s stream... (id: %s)", stm.ID()) } func (s *server) unregisterStream(stm stream.C2S) { - s.inConns.Delete(stm.ID()) + s.inConnectionsMu.Lock() + delete(s.inConnections, stm.ID()) + s.inConnectionsMu.Unlock() + log.Infof("unregistered c2s stream... (id: %s)", stm.ID()) } @@ -156,26 +141,24 @@ func (s *server) nextID() string { return fmt.Sprintf("c2s:%s:%d", s.cfg.ID, atomic.AddUint64(&s.stmSeq, 1)) } -func closeConnections(ctx context.Context, connections *sync.Map) (count int, err error) { - connections.Range(func(_, v interface{}) bool { - stm := v.(stream.InStream) +func (s *server) closeConnections(ctx context.Context) (count int, err error) { + s.inConnectionsMu.Lock() + for _, stm := range s.inConnections { select { - case <-closeConn(stm): + case <-closeConn(ctx, stm): count++ - return true case <-ctx.Done(): - count = 0 - err = ctx.Err() - return false + return 0, ctx.Err() } - }) - return + } + s.inConnectionsMu.Unlock() + return count, nil } -func closeConn(stm stream.InStream) <-chan bool { +func closeConn(ctx context.Context, stm stream.InStream) <-chan bool { c := make(chan bool, 1) go func() { - stm.Disconnect(streamerror.ErrSystemShutdown) + stm.Disconnect(ctx, streamerror.ErrSystemShutdown) c <- true }() return c diff --git a/c2s/server_test.go b/c2s/server_test.go index bc52e6fb7..3a1a8efdb 100644 --- a/c2s/server_test.go +++ b/c2s/server_test.go @@ -7,26 +7,19 @@ package c2s import ( "context" - "crypto/tls" "net" - "net/http" "testing" "time" - "github.com/gorilla/websocket" "github.com/ortuman/jackal/component" "github.com/ortuman/jackal/module" - "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/transport" - "github.com/ortuman/jackal/util" "github.com/stretchr/testify/require" ) func TestC2SSocketServer(t *testing.T) { - r, _, shutdown := setupTest("localhost") - defer shutdown() + r, _, _ := setupTest("localhost") errCh := make(chan error) cfg := Config{ @@ -39,7 +32,13 @@ func TestC2SSocketServer(t *testing.T) { Port: 9998, }, } - srv := server{cfg: &cfg, router: r, mods: &module.Modules{}, comps: &component.Components{}} + srv := server{ + cfg: &cfg, + router: r, + mods: &module.Modules{}, + comps: &component.Components{}, + inConnections: make(map[string]stream.C2S), + } go srv.start() go func() { @@ -64,68 +63,9 @@ func TestC2SSocketServer(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*5)) defer cancel() - srv.shutdown(ctx) + _ = srv.shutdown(ctx) errCh <- nil }() err := <-errCh require.Nil(t, err) } - -func TestC2SWebSocketServer(t *testing.T) { - privKeyFile := "../testdata/cert/test.server.key" - certFile := "../testdata/cert/test.server.crt" - cer, err := util.LoadCertificate(privKeyFile, certFile, "localhost") - require.Nil(t, err) - - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: "localhost", Certificate: cer}}, - }) - s := memstorage.New() - storage.Set(s) - defer storage.Unset() - - errCh := make(chan error) - cfg := Config{ - ID: "srv-1234", - ConnectTimeout: time.Second * time.Duration(5), - MaxStanzaSize: 8192, - ResourceConflict: Reject, - Transport: TransportConfig{ - Type: transport.WebSocket, - URLPath: "/xmpp/ws", - Port: 9999, - }, - } - srv := server{cfg: &cfg, router: r, mods: &module.Modules{}, comps: &component.Components{}} - go srv.start() - - go func() { - time.Sleep(time.Millisecond * 150) - d := &websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - h := http.Header{"Sec-WebSocket-Protocol": []string{"xmpp"}} - conn, _, err := d.Dial("wss://127.0.0.1:9999/xmpp/ws", h) - if err != nil { - errCh <- err - return - } - open := []byte(``) - err = conn.WriteMessage(websocket.TextMessage, open) - if err != nil { - errCh <- err - return - } - - time.Sleep(time.Millisecond * 150) // wait until disconnected - - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*5)) - defer cancel() - - srv.shutdown(ctx) - errCh <- nil - }() - err = <-errCh - require.Nil(t, err) -} diff --git a/cluster/c2s.go b/cluster/c2s.go deleted file mode 100644 index 9a11dfcf1..000000000 --- a/cluster/c2s.go +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "sync" - - "github.com/ortuman/jackal/xmpp" - "github.com/ortuman/jackal/xmpp/jid" -) - -type c2sCluster interface { - LocalNode() string - SendMessageTo(node string, msg *Message) -} - -// C2S represents a cluster c2s stream. -type C2S struct { - identifier string - cluster c2sCluster - node string - jid *jid.JID - mu sync.RWMutex - presence *xmpp.Presence - contextMu sync.RWMutex - context map[string]interface{} -} - -func newC2S( - identifier string, - jid *jid.JID, - presence *xmpp.Presence, - context map[string]interface{}, - node string, - cluster c2sCluster) *C2S { - s := &C2S{ - identifier: identifier, - cluster: cluster, - node: node, - jid: jid, - presence: presence, - context: context, - } - return s -} - -// ID returns stream identifier. -func (s *C2S) ID() string { - return s.identifier -} - -// Context returns a copy of the stream associated context. -func (s *C2S) Context() map[string]interface{} { - m := make(map[string]interface{}) - s.contextMu.RLock() - for k, v := range s.context { - m[k] = v - } - s.contextMu.RUnlock() - return m -} - -// SetString associates a string context value to a key. -func (s *C2S) SetString(key string, value string) {} - -// GetString returns the context value associated with the key as a string. -func (s *C2S) GetString(key string) string { - var ret string - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if s, ok := s.context[key].(string); ok { - ret = s - } - return ret -} - -// SetInt associates an integer context value to a key. -func (s *C2S) SetInt(key string, value int) {} - -// GetInt returns the context value associated with the key as an integer. -func (s *C2S) GetInt(key string) int { - var ret int - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if i, ok := s.context[key].(int); ok { - ret = i - } - return ret -} - -// SetFloat associates a float context value to a key. -func (s *C2S) SetFloat(key string, value float64) {} - -// GetFloat returns the context value associated with the key as a float64. -func (s *C2S) GetFloat(key string) float64 { - var ret float64 - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if f, ok := s.context[key].(float64); ok { - ret = f - } - return ret -} - -// SetBool associates a boolean context value to a key. -func (s *C2S) SetBool(key string, value bool) {} - -// GetBool returns the context value associated with the key as a boolean. -func (s *C2S) GetBool(key string) bool { - var ret bool - s.contextMu.RLock() - defer s.contextMu.RUnlock() - if b, ok := s.context[key].(bool); ok { - ret = b - } - return ret -} - -// UpdateContext updates stream context by copying all 'm' values -func (s *C2S) UpdateContext(m map[string]interface{}) { - s.contextMu.Lock() - for k, v := range m { - s.context[k] = v - } - s.contextMu.Unlock() -} - -// Username returns current stream username. -func (s *C2S) Username() string { - return s.jid.Node() -} - -// Domain returns current stream domain. -func (s *C2S) Domain() string { - return s.jid.Domain() -} - -// Resource returns current stream resource. -func (s *C2S) Resource() string { - return s.jid.Resource() -} - -// JID returns current user JID. -func (s *C2S) JID() *jid.JID { - return s.jid -} - -// IsAuthenticated returns whether or not the XMPP stream has successfully authenticated. -func (s *C2S) IsAuthenticated() bool { return true } - -// IsSecured returns whether or not the XMPP stream has been secured using SSL/TLS. -func (s *C2S) IsSecured() bool { return true } - -// Presence returns last sent presence element. -func (s *C2S) Presence() *xmpp.Presence { - s.mu.RLock() - defer s.mu.RUnlock() - return s.presence -} - -// SetPresence updates the C2S stream presence. -func (s *C2S) SetPresence(presence *xmpp.Presence) { - s.mu.Lock() - s.presence = presence - s.mu.Unlock() -} - -// Disconnect disconnects remote peer by closing the underlying TCP socket connection. -func (s *C2S) Disconnect(err error) {} - -// SendElement writes an XMPP element to the stream. -func (s *C2S) SendElement(elem xmpp.XElement) { - stanza, ok := elem.(xmpp.Stanza) - if !ok { - return - } - s.cluster.SendMessageTo(s.node, &Message{ - Type: MsgRouteStanza, - Node: s.cluster.LocalNode(), - Payloads: []MessagePayload{{ - JID: s.jid, - Stanza: stanza, - }}, - }) -} diff --git a/cluster/c2s_test.go b/cluster/c2s_test.go deleted file mode 100644 index 7c09273af..000000000 --- a/cluster/c2s_test.go +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "testing" - - "github.com/google/uuid" - "github.com/ortuman/jackal/xmpp" - "github.com/ortuman/jackal/xmpp/jid" - "github.com/stretchr/testify/require" -) - -type fakeC2SCluster struct { - sendMessageToCalls int -} - -func (c *fakeC2SCluster) LocalNode() string { return "node1" } -func (c *fakeC2SCluster) SendMessageTo(node string, msg *Message) { c.sendMessageToCalls++ } - -func TestC2S_New(t *testing.T) { - var c fakeC2SCluster - - id := uuid.New().String() - stm := newTestClusterC2S(id, "ortuman@jackal.im/balcony", xmpp.AvailableType, map[string]interface{}{}, "node1", &c) - - require.Equal(t, id, stm.ID()) - require.True(t, stm.IsSecured()) - require.True(t, stm.IsAuthenticated()) - - require.Equal(t, "ortuman", stm.Username()) - require.Equal(t, "jackal.im", stm.Domain()) - require.Equal(t, "balcony", stm.Resource()) - - j := stm.JID() - require.NotNil(t, j) - require.Equal(t, "ortuman", j.Node()) - require.Equal(t, "jackal.im", j.Domain()) - require.Equal(t, "balcony", j.Resource()) -} - -func TestC2S_Presence(t *testing.T) { - var c fakeC2SCluster - - id := uuid.New().String() - stm := newTestClusterC2S(id, "ortuman@jackal.im/balcony", xmpp.AvailableType, map[string]interface{}{}, "node1", &c) - - p := stm.Presence() - require.NotNil(t, p) - require.Equal(t, xmpp.AvailableType, p.Type()) - - // change presence - p = xmpp.NewPresence(p.FromJID(), p.ToJID(), xmpp.UnavailableType) - stm.SetPresence(p) - require.Equal(t, p, stm.Presence()) -} - -func TestC2S_Context(t *testing.T) { - var c fakeC2SCluster - - context := map[string]interface{}{ - "a1": true, - "b1": 3.14, - "c1": 35, - "d1": "foo", - } - contextLength := len(context) - - id := uuid.New().String() - stm := newTestClusterC2S(id, "ortuman@jackal.im/balcony", xmpp.AvailableType, context, "node1", &c) - - // setters don't do anything - stm.SetBool("a2", true) - stm.SetFloat("b2", 3.14) - stm.SetInt("c2", 35) - stm.SetString("d2", "foo") - - require.Equal(t, contextLength, len(stm.Context())) - - require.True(t, stm.GetBool("a1")) - require.Equal(t, 3.14, stm.GetFloat("b1")) - require.Equal(t, 35, stm.GetInt("c1")) - require.Equal(t, "foo", stm.GetString("d1")) - - // update context - stm.UpdateContext(map[string]interface{}{ - "e1": "foo2", - }) - - require.Equal(t, contextLength+1, len(stm.Context())) - require.Equal(t, "foo2", stm.GetString("e1")) -} - -func TestC2S_SendElement(t *testing.T) { - var c fakeC2SCluster - - id := uuid.New().String() - stm := newTestClusterC2S(id, "ortuman@jackal.im/balcony", xmpp.AvailableType, map[string]interface{}{}, "node1", &c) - - stm.SendElement(xmpp.NewElementName("vCard")) // not a stanza - stm.SendElement(xmpp.NewIQType(uuid.New().String(), xmpp.GetType)) - - require.Equal(t, 1, c.sendMessageToCalls) -} - -func newTestClusterC2S(id string, jidString string, presenceType string, context map[string]interface{}, node string, c2sCluster c2sCluster) *C2S { - j, _ := jid.NewWithString(jidString, true) - p := xmpp.NewPresence(j, j, xmpp.AvailableType) - return newC2S(id, j, p, context, node, c2sCluster) -} diff --git a/cluster/cluster.go b/cluster/cluster.go deleted file mode 100644 index 184c10a7e..000000000 --- a/cluster/cluster.go +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "bytes" - "sync" - - "github.com/ortuman/jackal/runqueue" - - "github.com/google/uuid" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/xmpp" - "github.com/ortuman/jackal/xmpp/jid" -) - -const clusterMailboxSize = 32768 - -var createMemberList = func(localName string, bindPort int, cluster *Cluster) (memberList, error) { - return newDefaultMemberList(localName, bindPort, cluster) -} - -// Metadata type represents all metadata information associated to a node. -type Metadata struct { - Version string - GoVersion string -} - -// Node represents a concrete c2s node and metadata information. -type Node struct { - Name string - Metadata Metadata -} - -// Delegate is the interface that will receive all c2s related events. -type Delegate interface { - NodeJoined(node *Node) - NodeUpdated(node *Node) - NodeLeft(node *Node) - - NotifyMessage(msg *Message) -} - -// memberList interface defines the common c2s member list methods. -type memberList interface { - Members() []Node - - Join(hosts []string) error - Shutdown() error - - SendReliable(node string, msg []byte) error -} - -// Cluster represents a c2s sub system. -type Cluster struct { - cfg *Config - buf *bytes.Buffer - delegate Delegate - memberList memberList - membersMu sync.RWMutex - members map[string]*Node - runQueue *runqueue.RunQueue -} - -// New returns an initialized c2s instance -func New(config *Config, delegate Delegate) (*Cluster, error) { - if config == nil { - return nil, nil - } - c := &Cluster{ - cfg: config, - delegate: delegate, - buf: bytes.NewBuffer(nil), - members: make(map[string]*Node), - runQueue: runqueue.New("cluster"), - } - ml, err := createMemberList(config.Name, config.BindPort, c) - if err != nil { - return nil, err - } - c.memberList = ml - return c, nil -} - -// Join tries to join the c2s by contacting all the given hosts. -func (c *Cluster) Join() error { - log.Infof("local node: %s", c.LocalNode()) - - c.membersMu.Lock() - for _, m := range c.memberList.Members() { - if m.Name == c.LocalNode() { - continue - } - log.Infof("registered cluster node: %s", m.Name) - c.members[m.Name] = &m - } - c.membersMu.Unlock() - return c.memberList.Join(c.cfg.Hosts) -} - -// LocalNode returns the local node identifier. -func (c *Cluster) LocalNode() string { - return c.cfg.Name -} - -// C2SStream returns a cluster C2S stream. -func (c *Cluster) C2SStream(jid *jid.JID, presence *xmpp.Presence, context map[string]interface{}, node string) *C2S { - return newC2S(uuid.New().String(), jid, presence, context, node, c) -} - -// SendMessageTo sends a cluster message to a concrete node. -func (c *Cluster) SendMessageTo(node string, msg *Message) { - c.runQueue.Run(func() { - if err := c.send(msg, node); err != nil { - log.Error(err) - return - } - }) -} - -// BroadcastMessage broadcasts a cluster message to all nodes. -func (c *Cluster) BroadcastMessage(msg *Message) { - c.runQueue.Run(func() { - if err := c.broadcast(msg); err != nil { - log.Error(err) - } - }) -} - -// Shutdown shuts down cluster sub system. -func (c *Cluster) Shutdown() error { - errCh := make(chan error, 1) - c.runQueue.Stop(func() { - errCh <- c.memberList.Shutdown() - }) - return <-errCh -} - -func (c *Cluster) send(msg *Message, toNode string) error { - return c.memberList.SendReliable(toNode, c.encodeMessage(msg)) -} - -func (c *Cluster) broadcast(msg *Message) error { - msgBytes := c.encodeMessage(msg) - - c.membersMu.RLock() - defer c.membersMu.RUnlock() - - for _, node := range c.members { - if node.Name == c.LocalNode() { - continue - } - if err := c.memberList.SendReliable(node.Name, msgBytes); err != nil { - return err - } - } - return nil -} - -func (c *Cluster) handleNotifyJoin(n *Node) { - if n.Name == c.LocalNode() { - return - } - c.membersMu.Lock() - c.members[n.Name] = n - c.membersMu.Unlock() - - log.Infof("registered cluster node: %s", n.Name) - if c.delegate != nil && n.Name != c.LocalNode() { - c.delegate.NodeJoined(n) - } -} - -func (c *Cluster) handleNotifyUpdate(n *Node) { - if n.Name == c.LocalNode() { - return - } - c.membersMu.Lock() - c.members[n.Name] = n - c.membersMu.Unlock() - - log.Infof("updated cluster node: %s", n.Name) - if c.delegate != nil && n.Name != c.LocalNode() { - c.delegate.NodeUpdated(n) - } -} - -func (c *Cluster) handleNotifyLeave(n *Node) { - if n.Name == c.LocalNode() { - return - } - c.membersMu.Lock() - delete(c.members, n.Name) - c.membersMu.Unlock() - - log.Infof("unregistered cluster node: %s", n.Name) - if c.delegate != nil && n.Name != c.LocalNode() { - c.delegate.NodeLeft(n) - } -} - -func (c *Cluster) handleNotifyMsg(msg []byte) { - if len(msg) == 0 { - return - } - var m Message - buf := bytes.NewBuffer(msg) - if err := m.FromBytes(buf); err != nil { - log.Error(err) - return - } - if c.delegate != nil { - c.delegate.NotifyMessage(&m) - } -} - -func (c *Cluster) encodeMessage(msg *Message) []byte { - defer c.buf.Reset() - - _ = msg.ToBytes(c.buf) - msgBytes := make([]byte, c.buf.Len(), c.buf.Len()) - copy(msgBytes, c.buf.Bytes()) - return msgBytes -} diff --git a/cluster/cluster_test.go b/cluster/cluster_test.go deleted file mode 100644 index a7fbbbe77..000000000 --- a/cluster/cluster_test.go +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "bytes" - "errors" - "testing" - "time" - - "github.com/ortuman/jackal/xmpp/jid" - "github.com/stretchr/testify/require" -) - -const clusterOpTimeout = time.Millisecond * 250 - -type fakeClusterDelegate struct { - nodeJoinedCalls int - nodeUpdatedCalls int - nodeLeftCalls int - notifyMessageCalls int -} - -func (d *fakeClusterDelegate) NodeJoined(node *Node) { d.nodeJoinedCalls++ } -func (d *fakeClusterDelegate) NodeUpdated(node *Node) { d.nodeUpdatedCalls++ } -func (d *fakeClusterDelegate) NodeLeft(node *Node) { d.nodeLeftCalls++ } -func (d *fakeClusterDelegate) NotifyMessage(msg *Message) { d.notifyMessageCalls++ } - -type fakeMemberList struct { - members []Node - joinHosts []string - sendErr error - sendCh chan []byte - shutdownCh chan struct{} - membersCalls int - joinCalls int - shutdownCalls int - sendReliableCalls int -} - -func (ml *fakeMemberList) Members() []Node { - ml.membersCalls++ - return ml.members -} - -func (ml *fakeMemberList) Join(hosts []string) error { - ml.joinHosts = hosts - ml.joinCalls++ - return nil -} - -func (ml *fakeMemberList) Shutdown() error { - if ml.shutdownCh != nil { - close(ml.shutdownCh) - } - ml.shutdownCalls++ - return nil -} - -func (ml *fakeMemberList) SendReliable(node string, msg []byte) error { - if ml.sendErr != nil { - return ml.sendErr - } - if ml.sendCh != nil { - ml.sendCh <- msg - } - ml.sendReliableCalls++ - return nil -} - -func TestCluster_Create(t *testing.T) { - var ml fakeMemberList - createMemberList = func(_ string, _ int, _ *Cluster) (list memberList, e error) { - return &ml, nil - } - c, _ := New(nil, nil) - require.Nil(t, c) - - c, _ = New(testClusterConfig(), nil) - require.NotNil(t, c) - require.Equal(t, "node1", c.LocalNode()) -} - -func TestCluster_Shutdown(t *testing.T) { - var ml fakeMemberList - createMemberList = func(_ string, _ int, _ *Cluster) (list memberList, e error) { - return &ml, nil - } - c, _ := New(testClusterConfig(), nil) - require.NotNil(t, c) - - ml.shutdownCh = make(chan struct{}) - - _ = c.Shutdown() - select { - case <-ml.shutdownCh: - break - case <-time.After(clusterOpTimeout): - require.Fail(t, "cluster shutdown timeout") - } -} - -func TestCluster_Join(t *testing.T) { - var ml fakeMemberList - createMemberList = func(_ string, _ int, _ *Cluster) (list memberList, e error) { - return &ml, nil - } - c, _ := New(testClusterConfig(), nil) - require.NotNil(t, c) - - ml.members = []Node{{Name: "node2"}, {Name: "node3"}} - err := c.Join() - require.Nil(t, err) - - require.Equal(t, 1, ml.membersCalls) - require.Equal(t, 1, ml.joinCalls) - - require.Equal(t, 2, len(ml.joinHosts)) -} - -func TestCluster_SendAndBroadcast(t *testing.T) { - var ml fakeMemberList - createMemberList = func(_ string, _ int, _ *Cluster) (list memberList, e error) { - return &ml, nil - } - c, _ := New(testClusterConfig(), nil) - require.NotNil(t, c) - - ml.members = []Node{{Name: "node2"}, {Name: "node3"}} - err := c.Join() - require.Nil(t, err) - - ml.sendCh = make(chan []byte) - c.SendMessageTo("node3", &Message{}) - select { - case <-ml.sendCh: - break - case <-time.After(clusterOpTimeout): - require.Fail(t, "cluster send message timeout") - } - - c.BroadcastMessage(&Message{}) - - for i := 0; i < 2; i++ { - select { - case <-ml.sendCh: - break - case <-time.After(clusterOpTimeout): - require.Fail(t, "cluster broadcast message timeout") - } - } - - // test send error - ml.sendErr = errors.New("cluster: send error") - - c.SendMessageTo("node3", &Message{}) - select { - case <-ml.sendCh: - require.Fail(t, "unexpected send message") - case <-time.After(clusterOpTimeout): - break - } - - c.BroadcastMessage(&Message{}) - - for i := 0; i < 2; i++ { - select { - case <-ml.sendCh: - require.Fail(t, "unexpected broadcast message") - case <-time.After(clusterOpTimeout): - break - } - } -} - -func TestCluster_Delegate(t *testing.T) { - var ml fakeMemberList - var delegate fakeClusterDelegate - - createMemberList = func(_ string, _ int, _ *Cluster) (list memberList, e error) { - return &ml, nil - } - c, _ := New(testClusterConfig(), &delegate) - require.NotNil(t, c) - - c.handleNotifyJoin(&Node{Name: "node4"}) - require.Equal(t, 1, delegate.nodeJoinedCalls) - - c.handleNotifyUpdate(&Node{Name: "node4"}) - require.Equal(t, 1, delegate.nodeUpdatedCalls) - - c.handleNotifyLeave(&Node{Name: "node4"}) - require.Equal(t, 1, delegate.nodeLeftCalls) - - j, _ := jid.NewWithString("ortuman@jackal.im/garden", true) - var m Message - m.Type = MsgBind - m.Node = "node3" - m.Payloads = []MessagePayload{{JID: j}} - - buf := bytes.NewBuffer(nil) - require.Nil(t, m.ToBytes(buf)) - - c.handleNotifyMsg(buf.Bytes()) - require.Equal(t, 1, delegate.notifyMessageCalls) -} - -func testClusterConfig() *Config { - return &Config{ - Name: "node1", - BindPort: 9999, - Hosts: []string{"127.0.0.1:6666", "127.0.0.1:7777"}, - } -} diff --git a/cluster/config.go b/cluster/config.go deleted file mode 100644 index d99253ac1..000000000 --- a/cluster/config.go +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -// Config represents an cluster configuration. -type Config struct { - Name string `yaml:"name"` - BindPort int `yaml:"port"` - Hosts []string `yaml:"hosts"` -} diff --git a/cluster/member_list_test.go b/cluster/member_list_test.go deleted file mode 100644 index 33c28bac8..000000000 --- a/cluster/member_list_test.go +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "bytes" - "encoding/gob" - "errors" - "runtime" - "testing" - "time" - - "github.com/hashicorp/memberlist" - "github.com/ortuman/jackal/version" - "github.com/stretchr/testify/require" -) - -type fakeHashicorpMemberList struct { - err error - joinCalls int - leaveCalls int - shutdownCalls int - sendReliableCalls int -} - -func (ml *fakeHashicorpMemberList) Join(existing []string) (int, error) { - if ml.err != nil { - return 0, ml.err - } - ml.joinCalls++ - return len(existing), nil -} - -func (ml *fakeHashicorpMemberList) Leave(timeout time.Duration) error { - if ml.err != nil { - return ml.err - } - ml.leaveCalls++ - return nil -} - -func (ml *fakeHashicorpMemberList) Shutdown() error { - if ml.err != nil { - return ml.err - } - ml.shutdownCalls++ - return nil -} - -func (ml *fakeHashicorpMemberList) SendReliable(to *memberlist.Node, msg []byte) error { - if ml.err != nil { - return ml.err - } - ml.sendReliableCalls++ - return nil -} - -type fakeMemberListDelegate struct { - notifyMsgCalls int - notifyJoinCalls int - notifyUpdateCalls int - notifyLeaveCalls int -} - -func (d *fakeMemberListDelegate) handleNotifyMsg(msg []byte) { d.notifyMsgCalls++ } -func (d *fakeMemberListDelegate) handleNotifyJoin(n *Node) { d.notifyJoinCalls++ } -func (d *fakeMemberListDelegate) handleNotifyUpdate(n *Node) { d.notifyUpdateCalls++ } -func (d *fakeMemberListDelegate) handleNotifyLeave(n *Node) { d.notifyLeaveCalls++ } - -func TestClusterMemberList_Members(t *testing.T) { - var ml fakeHashicorpMemberList - var delegate fakeMemberListDelegate - - createHashicorpMemberList = func(_ *memberlist.Config) (list hashicorpMemberList, e error) { - return &ml, nil - } - cMemberList, _ := newDefaultMemberList("node1", 6666, &delegate) - cMemberList.NotifyJoin(memberListNode("node1")) - cMemberList.NotifyJoin(memberListNode("node2")) - cMemberList.NotifyJoin(memberListNode("node3")) - - // no metadata included... node won't be added - cMemberList.NotifyJoin(&memberlist.Node{Name: "node4"}) - - require.Equal(t, 3, delegate.notifyJoinCalls) - - cMemberList.NotifyUpdate(&memberlist.Node{Name: "node2"}) - cMemberList.NotifyUpdate(memberListNode("node2")) - - require.Equal(t, 1, delegate.notifyUpdateCalls) - - cMemberList.NotifyLeave(&memberlist.Node{Name: "node3"}) - cMemberList.NotifyLeave(memberListNode("node3")) - - require.Equal(t, 1, delegate.notifyLeaveCalls) - - require.Equal(t, 2, len(cMemberList.Members())) -} - -func TestClusterMemberList_Join(t *testing.T) { - var ml fakeHashicorpMemberList - var delegate fakeMemberListDelegate - - createHashicorpMemberList = func(_ *memberlist.Config) (list hashicorpMemberList, e error) { - return &ml, nil - } - cMemberList, _ := newDefaultMemberList("node1", 6666, &delegate) - - err := cMemberList.Join([]string{"127.0.0.1:7777", "127.0.0.1:8888"}) - require.Nil(t, err) - require.Equal(t, 1, ml.joinCalls) - - ml.err = errors.New("") - err = cMemberList.Join([]string{"127.0.0.1:7777", "127.0.0.1:8888"}) - require.NotNil(t, err) - require.Equal(t, 1, ml.joinCalls) -} - -func TestClusterMemberList_Shutdown(t *testing.T) { - var ml fakeHashicorpMemberList - var delegate fakeMemberListDelegate - - createHashicorpMemberList = func(_ *memberlist.Config) (list hashicorpMemberList, e error) { - return &ml, nil - } - cMemberList, _ := newDefaultMemberList("node1", 6666, &delegate) - err := cMemberList.Shutdown() - require.Nil(t, err) - require.Equal(t, 1, ml.leaveCalls) - require.Equal(t, 1, ml.shutdownCalls) - - ml.err = errors.New("") - err = cMemberList.Shutdown() - require.NotNil(t, err) - require.Equal(t, 1, ml.leaveCalls) - require.Equal(t, 1, ml.shutdownCalls) -} - -func TestClusterMemberList_SendReliable(t *testing.T) { - var ml fakeHashicorpMemberList - var delegate fakeMemberListDelegate - - createHashicorpMemberList = func(_ *memberlist.Config) (list hashicorpMemberList, e error) { - return &ml, nil - } - cMemberList, _ := newDefaultMemberList("node1", 6666, &delegate) - err := cMemberList.SendReliable("node2", []byte{}) - require.NotNil(t, err) // node2 has not joined - require.Equal(t, 0, ml.sendReliableCalls) - - cMemberList.NotifyJoin(memberListNode("node2")) // node2 joins - - err = cMemberList.SendReliable("node2", []byte{}) - require.Nil(t, err) - require.Equal(t, 1, ml.sendReliableCalls) - - ml.err = errors.New("") - err = cMemberList.SendReliable("node2", []byte{}) - require.NotNil(t, err) - require.Equal(t, 1, ml.sendReliableCalls) -} - -func TestClusterMemberList_NodeMetadata(t *testing.T) { - var ml fakeHashicorpMemberList - var delegate fakeMemberListDelegate - - createHashicorpMemberList = func(_ *memberlist.Config) (list hashicorpMemberList, e error) { - return &ml, nil - } - cMemberList, _ := newDefaultMemberList("node1", 6666, &delegate) - require.Nil(t, cMemberList.NodeMeta(1)) - - b := cMemberList.NodeMeta(10000) - var meta Metadata - _ = gob.NewDecoder(bytes.NewReader(b)).Decode(&meta) - - require.Equal(t, meta.Version, version.ApplicationVersion.String()) - require.Equal(t, meta.GoVersion, runtime.Version()) -} - -func memberListNode(name string) *memberlist.Node { - var m Metadata - m.Version = version.ApplicationVersion.String() - m.GoVersion = runtime.Version() - - buf := bytes.NewBuffer(nil) - _ = gob.NewEncoder(buf).Encode(&m) - b := make([]byte, buf.Len()) - copy(b, buf.Bytes()) - return &memberlist.Node{ - Name: name, - Meta: b, - } -} diff --git a/cluster/memberlist.go b/cluster/memberlist.go deleted file mode 100644 index 93c6118dd..000000000 --- a/cluster/memberlist.go +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "bytes" - "encoding/gob" - "fmt" - "io/ioutil" - "runtime" - "sync" - "time" - - "github.com/hashicorp/memberlist" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/version" -) - -const leaveTimeout = time.Second * 5 - -var createHashicorpMemberList = func(conf *memberlist.Config) (hashicorpMemberList, error) { - return memberlist.Create(conf) -} - -type hashicorpMemberList interface { - Join(existing []string) (int, error) - - Leave(timeout time.Duration) error - Shutdown() error - - SendReliable(to *memberlist.Node, msg []byte) error -} - -type memberListDelegate interface { - handleNotifyMsg(msg []byte) - - handleNotifyJoin(n *Node) - handleNotifyUpdate(n *Node) - handleNotifyLeave(n *Node) -} - -type defaultMemberList struct { - delegate memberListDelegate - ml hashicorpMemberList - mu sync.RWMutex - members map[string]*memberlist.Node -} - -func newDefaultMemberList(localName string, bindPort int, delegate memberListDelegate) (*defaultMemberList, error) { - dl := &defaultMemberList{ - delegate: delegate, - members: make(map[string]*memberlist.Node), - } - conf := memberlist.DefaultLocalConfig() - conf.Name = localName - conf.BindPort = bindPort - conf.Delegate = dl - conf.Events = dl - conf.LogOutput = ioutil.Discard - - ml, err := createHashicorpMemberList(conf) - if err != nil { - return nil, err - } - dl.ml = ml - return dl, nil -} - -func (d *defaultMemberList) Members() []Node { - var res []Node - d.mu.RLock() - for _, n := range d.members { - cNode, err := d.clusterNodeFromMemberListNode(n) - if err != nil { - log.Warnf("%s", err) - continue - } - res = append(res, *cNode) - } - d.mu.RUnlock() - return res -} - -func (d *defaultMemberList) Join(hosts []string) error { - _, err := d.ml.Join(hosts) - return err -} - -func (d *defaultMemberList) Shutdown() error { - if err := d.ml.Leave(leaveTimeout); err != nil { - return err - } - return d.ml.Shutdown() -} - -func (d *defaultMemberList) SendReliable(toNode string, msg []byte) error { - d.mu.RLock() - defer d.mu.RUnlock() - node := d.members[toNode] - if node == nil { - return fmt.Errorf("cannot send message: node %s not found", toNode) - } - return d.ml.SendReliable(node, msg) -} - -// memberlist.Delegate - -func (d *defaultMemberList) NodeMeta(limit int) []byte { - var m Metadata - m.Version = version.ApplicationVersion.String() - m.GoVersion = runtime.Version() - - buf := bytes.NewBuffer(nil) - if err := gob.NewEncoder(buf).Encode(&m); err != nil { - log.Error(err) - return nil - } - if buf.Len() > limit { - log.Errorf("node metadata exceeds length limit of %d bytes", limit) - return nil - } - b := make([]byte, buf.Len()) - copy(b, buf.Bytes()) - return b -} - -func (d *defaultMemberList) NotifyMsg(msg []byte) { - d.delegate.handleNotifyMsg(msg) -} - -func (d *defaultMemberList) GetBroadcasts(overhead, limit int) [][]byte { return nil } -func (d *defaultMemberList) LocalState(join bool) []byte { return nil } -func (d *defaultMemberList) MergeRemoteState(buf []byte, join bool) {} - -// memberlist.EventDelegate - -func (d *defaultMemberList) NotifyJoin(n *memberlist.Node) { - d.mu.Lock() - d.members[n.Name] = n - d.mu.Unlock() - - cNode, err := d.clusterNodeFromMemberListNode(n) - if err != nil { - log.Warnf("%s", err) - return - } - d.delegate.handleNotifyJoin(cNode) -} - -func (d *defaultMemberList) NotifyLeave(n *memberlist.Node) { - d.mu.Lock() - delete(d.members, n.Name) - d.mu.Unlock() - - cNode, err := d.clusterNodeFromMemberListNode(n) - if err != nil { - log.Warnf("%s", err) - return - } - d.delegate.handleNotifyLeave(cNode) -} - -func (d *defaultMemberList) NotifyUpdate(n *memberlist.Node) { - d.mu.Lock() - d.members[n.Name] = n - d.mu.Unlock() - - cNode, err := d.clusterNodeFromMemberListNode(n) - if err != nil { - log.Warnf("%s", err) - return - } - d.delegate.handleNotifyUpdate(cNode) -} - -func (d *defaultMemberList) clusterNodeFromMemberListNode(n *memberlist.Node) (*Node, error) { - var m Metadata - if err := gob.NewDecoder(bytes.NewBuffer(n.Meta)).Decode(&m); err != nil { - return nil, err - } - return &Node{ - Name: n.Name, - Metadata: m, - }, nil -} diff --git a/cluster/message.go b/cluster/message.go deleted file mode 100644 index 5613b730e..000000000 --- a/cluster/message.go +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "bytes" - "encoding/gob" - - "github.com/ortuman/jackal/xmpp" - "github.com/ortuman/jackal/xmpp/jid" -) - -const ( - // MsgBind represents a bind cluster message. - MsgBind = iota - - // MsgBatchBind represents a batch bind cluster message. - MsgBatchBind - - // MsgUnbind represents a unbind cluster message. - MsgUnbind - - // MsgUpdatePresence represents an update presence cluster message. - MsgUpdatePresence - - // MsgUpdateContext represents a context update cluster message. - MsgUpdateContext - - // MsgRouteStanza represents a route stanza cluster message. - MsgRouteStanza -) - -const ( - messageStanza = iota - presenceStanza - iqStanza -) - -// MessagePayload represents a message payload type. -type MessagePayload struct { - JID *jid.JID - Context map[string]interface{} - Stanza xmpp.Stanza -} - -// FromBytes reads MessagePayload fields from its binary representation. -func (p *MessagePayload) FromBytes(buf *bytes.Buffer) error { - dec := gob.NewDecoder(buf) - j, err := jid.NewFromBytes(buf) - if err != nil { - return err - } - p.JID = j - - var hasContextMap bool - dec.Decode(&hasContextMap) - if hasContextMap { - var m map[string]interface{} - dec.Decode(&m) - p.Context = m - } - - var hasStanza bool - dec.Decode(&hasStanza) - if !hasStanza { - return nil - } - var stanzaType int - dec.Decode(&stanzaType) - switch stanzaType { - case messageStanza: - message, err := xmpp.NewMessageFromBytes(buf) - if err != nil { - return err - } - p.Stanza = message - case presenceStanza: - presence, err := xmpp.NewPresenceFromBytes(buf) - if err != nil { - return err - } - p.Stanza = presence - case iqStanza: - iq, err := xmpp.NewIQFromBytes(buf) - if err != nil { - return err - } - p.Stanza = iq - } - return nil -} - -// ToBytes converts a MessagePayload instance to its binary representation. -func (p *MessagePayload) ToBytes(buf *bytes.Buffer) error { - enc := gob.NewEncoder(buf) - if err := p.JID.ToBytes(buf); err != nil { - return err - } - - hasContextMap := p.Context != nil - if err := enc.Encode(&hasContextMap); err != nil { - return err - } - if hasContextMap { - if err := enc.Encode(&p.Context); err != nil { - return err - } - } - - hasStanza := p.Stanza != nil - if err := enc.Encode(&hasStanza); err != nil { - return err - } - if !hasStanza { - return nil - } - // store stanza type - switch p.Stanza.(type) { - case *xmpp.Message: - if err := enc.Encode(messageStanza); err != nil { - return err - } - case *xmpp.Presence: - if err := enc.Encode(presenceStanza); err != nil { - return err - } - case *xmpp.IQ: - if err := enc.Encode(iqStanza); err != nil { - return err - } - default: - return nil - } - return p.Stanza.ToBytes(buf) -} - -// Message is the c2s message type. -// A message can contain one or more payloads. -type Message struct { - Type int - Node string - Payloads []MessagePayload -} - -// FromBytes reads Message fields from its binary representation. -func (m *Message) FromBytes(buf *bytes.Buffer) error { - dec := gob.NewDecoder(buf) - if err := dec.Decode(&m.Type); err != nil { - return err - } - if err := dec.Decode(&m.Node); err != nil { - return err - } - - var pLen int - if err := dec.Decode(&pLen); err != nil { - return err - } - - m.Payloads = nil - for i := 0; i < pLen; i++ { - var p MessagePayload - if err := p.FromBytes(buf); err != nil { - return err - } - m.Payloads = append(m.Payloads, p) - } - return nil -} - -// ToBytes converts a Message instance to its binary representation. -func (m *Message) ToBytes(buf *bytes.Buffer) error { - enc := gob.NewEncoder(buf) - if err := enc.Encode(m.Type); err != nil { - return err - } - if err := enc.Encode(m.Node); err != nil { - return err - } - if err := enc.Encode(len(m.Payloads)); err != nil { - return err - } - for _, p := range m.Payloads { - if err := p.ToBytes(buf); err != nil { - return err - } - } - return nil -} diff --git a/cluster/message_test.go b/cluster/message_test.go deleted file mode 100644 index 5ae8023b0..000000000 --- a/cluster/message_test.go +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package cluster - -import ( - "bytes" - "testing" - - "github.com/google/uuid" - "github.com/ortuman/jackal/xmpp" - "github.com/ortuman/jackal/xmpp/jid" - "github.com/stretchr/testify/require" -) - -func TestMessageSerialization(t *testing.T) { - buf := bytes.NewBuffer(nil) - - var m1, m2 Message - m1 = Message{ - Type: MsgBatchBind, - Node: "node1", - } - require.Nil(t, m1.ToBytes(buf)) - - require.Nil(t, m2.FromBytes(buf)) - require.Equal(t, m1.Type, m2.Type) - require.Equal(t, m1.Node, m2.Node) - - j, _ := jid.NewWithString("ortuman@jackal.im", true) - m1 = Message{ - Type: MsgUpdatePresence, - Node: "node1", - Payloads: []MessagePayload{{ - JID: j, - Stanza: xmpp.NewPresence(j, j, xmpp.UnavailableType), - Context: map[string]interface{}{"requested": true}, - }}, - } - buf.Reset() - require.Nil(t, m1.ToBytes(buf)) - - require.Nil(t, m2.FromBytes(buf)) - require.Equal(t, m1.Type, m2.Type) - require.Equal(t, m1.Node, m2.Node) - require.Equal(t, 1, len(m2.Payloads)) - require.NotNil(t, m2.Payloads[0].JID) - require.NotNil(t, m2.Payloads[0].Stanza) - require.NotNil(t, m2.Payloads[0].Context) - - require.Equal(t, m1.Payloads[0].Context, m2.Payloads[0].Context) - require.Equal(t, m1.Payloads[0].JID.String(), m2.Payloads[0].JID.String()) - require.Equal(t, m1.Payloads[0].Stanza.String(), m2.Payloads[0].Stanza.String()) - _, ok := m2.Payloads[0].Stanza.(*xmpp.Presence) - require.True(t, ok) - - m1.Payloads[0].Stanza = xmpp.NewIQType(uuid.New().String(), xmpp.GetType) - buf.Reset() - require.Nil(t, m1.ToBytes(buf)) - - require.Nil(t, m2.FromBytes(buf)) - _, ok = m2.Payloads[0].Stanza.(*xmpp.IQ) - require.True(t, ok) - - m1.Payloads[0].Stanza = xmpp.NewMessageType(uuid.New().String(), xmpp.NormalType) - buf.Reset() - require.Nil(t, m1.ToBytes(buf)) - - require.Nil(t, m2.FromBytes(buf)) - _, ok = m2.Payloads[0].Stanza.(*xmpp.Message) - require.True(t, ok) -} diff --git a/component/component.go b/component/component.go index 364860f11..543bfb0db 100644 --- a/component/component.go +++ b/component/component.go @@ -18,7 +18,7 @@ import ( // Component represents a generic component interface. type Component interface { Host() string - ProcessStanza(stanza xmpp.Stanza, stm stream.C2S) + ProcessStanza(ctx context.Context, stanza xmpp.Stanza, stm stream.C2S) } // Components represents a set of preconfigured components. diff --git a/dockerfiles/jackal.yml b/dockerfiles/jackal.yml index 86d8a3f1c..90c17deb3 100644 --- a/dockerfiles/jackal.yml +++ b/dockerfiles/jackal.yml @@ -26,6 +26,7 @@ modules: - vcard # XEP-0054: vcard-temp - registration # XEP-0077: In-Band Registration - version # XEP-0092: Software Version + - pep # XEP-163: Personal Eventing Protocol - blocking_command # XEP-0191: Blocking Command - ping # XEP-0199: XMPP Ping - offline # Offline storage diff --git a/example.jackal.yml b/example.jackal.yml index 7f3c0eef0..f7b8b318e 100644 --- a/example.jackal.yml +++ b/example.jackal.yml @@ -18,6 +18,7 @@ storage: database: jackal pool_size: 16 +#storage: # type: pgsql # pgsql: # host: 127.0.0.1:5432 @@ -26,17 +27,11 @@ storage: # database: jackal # pool_size: 16 -# cluster: -# name: node1 -# port: 5010 -# hosts: [127.0.0.1:5009, 127.0.0.1:5011] - -router: - hosts: - - name: localhost - tls: - privkey_path: "" - cert_path: "" +hosts: + - name: localhost + tls: + privkey_path: "" + cert_path: "" modules: enabled: @@ -46,9 +41,11 @@ modules: - vcard # XEP-0054: vcard-temp - registration # XEP-0077: In-Band Registration - version # XEP-0092: Software Version + - pep # XEP-0163: Personal Eventing Protocol - blocking_command # XEP-0191: Blocking Command - ping # XEP-0199: XMPP Ping - offline # Offline storage + - muc # XEP-0045: Multi-User Chat mod_roster: versioning: true @@ -71,10 +68,30 @@ modules: send: no send_interval: 60 + mod_muc: + host: conference.jackal.im + name: "Chatroom Server" + room_defaults: + public: true + persistent: true + password_protected: false + open: true + moderated: false + allow_invites: false + allow_subject_change: true + enable_logging: true + non_anonymous: true + occupant_count: -1 # -1 means don't set the limit + # options for the next ones are "all", "moderators" and "" + can_get_member_list: "all" + send_pm: "all" + c2s: - id: default connect_timeout: 5 + keep_alive: 120 + max_stanza_size: 65536 resource_conflict: replace # [override, replace, reject] @@ -82,7 +99,6 @@ c2s: type: socket # websocket bind_addr: 0.0.0.0 port: 5222 - keep_alive: 120 # url_path: /xmpp/ws compression: @@ -90,17 +106,15 @@ c2s: sasl: - plain - - digest_md5 - scram_sha_1 - scram_sha_256 - - scram_sha_512 - -#s2s: -# dial_timeout: 15 -# dialback_secret: s3cr3tf0rd14lb4ck -# max_stanza_size: 131072 -# -# transport: -# bind_addr: 0.0.0.0 -# port: 5269 -# keep_alive: 600 + +s2s: + dial_timeout: 15 + keep_alive: 600 + dialback_secret: s3cr3tf0rd14lb4ck + max_stanza_size: 131072 + + transport: + bind_addr: 0.0.0.0 + port: 5269 diff --git a/go.mod b/go.mod index e7b52f4f3..86a21f237 100644 --- a/go.mod +++ b/go.mod @@ -1,60 +1,56 @@ module github.com/ortuman/jackal -go 1.12 +go 1.14 require ( - github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9 // indirect github.com/BurntSushi/toml v0.3.1 // indirect - github.com/DATA-DOG/go-sqlmock v0.0.0-20190322142548-ef0bdf231ae3 - github.com/Masterminds/squirrel v0.0.0-20190319150415-55303df43ec0 + github.com/DATA-DOG/go-sqlmock v1.3.3 + github.com/Masterminds/squirrel v1.1.0 github.com/antlr/antlr4 v0.0.0-20191011202612-ad2bd05285ca // indirect - github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878 // indirect github.com/bclicn/color v0.0.0-20180711051946-108f2023dc84 // indirect github.com/britram/borat v0.0.0-20181011130314-f891bcfcfb9b // indirect github.com/d4l3k/messagediff v1.2.1 // indirect github.com/dchest/cmac v0.0.0-20150527144652-62ff55a1048c // indirect - github.com/dgraph-io/badger v0.0.0-20190504012207-d2ebeac29495 - github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect - github.com/dustin/go-humanize v1.0.0 // indirect - github.com/go-sql-driver/mysql v0.0.0-20190423112050-d0a548181995 + github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/go-sql-driver/mysql v1.5.0 github.com/go-stack/stack v1.8.0 // indirect - github.com/google/btree v0.0.0-20190326150332-20236160a414 // indirect + github.com/golang/mock v1.4.1 // indirect + github.com/golang/protobuf v1.3.4 // indirect github.com/google/gopacket v1.1.17 // indirect - github.com/google/uuid v1.0.0 - github.com/gorilla/websocket v0.0.0-20190427040306-80c2d40e9b91 - github.com/hashicorp/go-msgpack v0.5.5 // indirect - github.com/hashicorp/go-sockaddr v1.0.2 // indirect - github.com/hashicorp/golang-lru v0.5.1 // indirect - github.com/hashicorp/memberlist v0.0.0-20190312092157-a8f83c6403e0 + github.com/google/uuid v1.1.1 github.com/inconshreveable/log15 v0.0.0-20180818164646-67afb5ed74ec // indirect github.com/kormat/fmt15 v0.0.0-20181112140556-ee69fecb2656 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/lib/pq v0.0.0-20190504011754-ceb88a064902 + github.com/lib/pq v1.3.0 github.com/lucas-clemente/quic-go v0.0.0-20190427152327-c135b4f1e34c github.com/mattn/go-colorable v0.1.1 // indirect github.com/mattn/go-sqlite3 v1.11.0 // indirect github.com/netsec-ethz/rains v0.0.0-20190912114116-83f56a7cb2d1 // indirect github.com/netsec-ethz/scion-apps v0.0.0-20191003104124-7237654083b2 + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pborman/uuid v1.2.0 github.com/philhofer/fwd v1.0.0 // indirect github.com/pierrec/lz4 v1.0.1 // indirect github.com/pierrec/xxHash v0.1.5 // indirect - github.com/pkg/errors v0.8.1 + github.com/pkg/errors v0.9.1 + github.com/prometheus/client_golang v0.9.2 // indirect + github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect github.com/scionproto/scion v0.0.0-00010101000000-000000000000 github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337 // indirect - github.com/sony/gobreaker v0.0.0-20190329013020-a9b2a3fc7395 - github.com/stretchr/testify v1.3.0 + github.com/sony/gobreaker v0.4.1 + github.com/stretchr/testify v1.6.1 github.com/tinylib/msgp v1.1.0 // indirect - golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 - golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c - golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82 // indirect - golang.org/x/text v0.3.2 - google.golang.org/appengine v1.5.0 // indirect + golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 + golang.org/x/net v0.0.0-20200602114024-627f9648deb9 + golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect + golang.org/x/text v0.3.3 + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/d4l3k/messagediff.v1 v1.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect gopkg.in/restruct.v1 v1.0.0-20190323193435-3c2afb705f3c // indirect - gopkg.in/yaml.v2 v2.2.2 + gopkg.in/yaml.v2 v2.3.0 + gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c // indirect zombiezen.com/go/capnproto2 v2.17.0+incompatible // indirect ) diff --git a/go.sum b/go.sum index 5fb33f652..50448508c 100644 --- a/go.sum +++ b/go.sum @@ -1,29 +1,19 @@ -github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9 h1:HD8gA2tkByhMAwYaFAX9w2l7vxvBQ5NMoxDrkhqhtn4= -github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/DATA-DOG/go-sqlmock v0.0.0-20190322142548-ef0bdf231ae3 h1:ZVutvTniOCWoGqHAsZcMrw1S7TRtwZXo8ihsw/d9mD0= -github.com/DATA-DOG/go-sqlmock v0.0.0-20190322142548-ef0bdf231ae3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= -github.com/Masterminds/squirrel v0.0.0-20190319150415-55303df43ec0 h1:KneV+BokSFaPEYWe588oUu3XJL4PVA954UYHODgMrUY= -github.com/Masterminds/squirrel v0.0.0-20190319150415-55303df43ec0/go.mod h1:yaPeOnPG5ZRwL9oKdTsO/prlkPbXWZlRVMQ/gGlzIuA= +github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08= +github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/Masterminds/squirrel v1.1.0 h1:baP1qLdoQCeTw3ifCdOq2dkYc6vGcmRdaociKLbEJXs= +github.com/Masterminds/squirrel v1.1.0/go.mod h1:yaPeOnPG5ZRwL9oKdTsO/prlkPbXWZlRVMQ/gGlzIuA= github.com/antlr/antlr4 v0.0.0-20191011202612-ad2bd05285ca h1:QHbltbNkVcw97h4zA/L8gA4o3dJiFvBZ0gyZHrYXHbs= github.com/antlr/antlr4 v0.0.0-20191011202612-ad2bd05285ca/go.mod h1:T7PbCXFs94rrTttyxjbyT5+/1V8T2TYDejxUfHJjw1Y= -github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= -github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878 h1:EFSB7Zo9Eg91v7MJPVsifUysc/wPdN+NOnVe6bWbdBM= -github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878/go.mod h1:3AMJUQhVx52RsWOnlkpikZr01T/yAVN2gn0861vByNg= -github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/bclicn/color v0.0.0-20180711051946-108f2023dc84 h1:cutFptzj+ospnc1PETUqcSVTH3VQ44Bi0rpt3nE9gvo= github.com/bclicn/color v0.0.0-20180711051946-108f2023dc84/go.mod h1:Va9ap1qxjAWkIVaW1E9rH0aNgE8SDI5A4n8Ds8P0fAA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/britram/borat v0.0.0-20181011130314-f891bcfcfb9b h1:eOJHzrH26TPsYqtMlhcRV5NZKwI7iopaFbYwhd03CjA= github.com/britram/borat v0.0.0-20181011130314-f891bcfcfb9b/go.mod h1:iEd9IJ9SwedxB5kO5ypZMVq7PUNDW5lhQy92rbWBLGk= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= -github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= -github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/d4l3k/messagediff v1.2.1 h1:ZcAIMYsUg0EAp9X+tt8/enBE/Q8Yd5kzPynLyKptt9U= github.com/d4l3k/messagediff v1.2.1/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -31,55 +21,26 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/cmac v0.0.0-20150527144652-62ff55a1048c h1:qoavXEzRRUfup81LsDQv4fnUQbLyorpPz6WxiwdiU7A= github.com/dchest/cmac v0.0.0-20150527144652-62ff55a1048c/go.mod h1:vWqNmss2I/DL9JKC95Lkwp+lzw+v8hwsQs7hQKyQpwk= -github.com/dgraph-io/badger v0.0.0-20190504012207-d2ebeac29495 h1:5WP6qOuO1v+irD11qR6AoOp/DYdedXKoEH2UYCvEwZY= -github.com/dgraph-io/badger v0.0.0-20190504012207-d2ebeac29495/go.mod h1:VZxzAIRPHRVNRKRo6AXrX9BJegn6il06VMTZVJYCIjQ= -github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= -github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= -github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/go-sql-driver/mysql v0.0.0-20190423112050-d0a548181995 h1:6lKNbIsmTxrYfFAmux0PTf+79tkqx83P+JKcsPplSVs= -github.com/go-sql-driver/mysql v0.0.0-20190423112050-d0a548181995/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/mock v1.4.1 h1:ocYkMQY5RrXTYgXl7ICpV0IXwlEQGwKIsery4gyXa1U= +github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v0.0.0-20190326150332-20236160a414 h1:/CWhUSWlZv5UXayDv2KZTI1+p6FxK8S/44LvGhI3WPQ= -github.com/google/btree v0.0.0-20190326150332-20236160a414/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/golang/protobuf v1.3.4 h1:87PNWwrRvUSnqS4dlcBU/ftvOIBep4sYuBLlh6rX2wk= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/google/gopacket v1.1.17 h1:rMrlX2ZY2UbvT+sdz3+6J+pp2z+msCq9MxTU6ymxbBY= github.com/google/gopacket v1.1.17/go.mod h1:UdDNZ1OO62aGYVnPhxT1U6aI7ukYtA/kB8vaU0diBUM= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/websocket v0.0.0-20190427040306-80c2d40e9b91 h1:6LmIJ1tEzp8HY6JlVw8FEuFN/JQtCteWXqME5jvOcQU= -github.com/gorilla/websocket v0.0.0-20190427040306-80c2d40e9b91/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-immutable-radix v1.0.0 h1:AKDB1HM5PWEA7i4nhcpwOrO2byshxBjXVn/J/3+z5/0= -github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= -github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-msgpack v0.5.5 h1:i9R9JSrqIz0QVLz3sz+i3YJdT7TTSLcfLLzJi9aZTuI= -github.com/hashicorp/go-msgpack v0.5.5/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= -github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= -github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= -github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= -github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= -github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= -github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= -github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/memberlist v0.0.0-20190312092157-a8f83c6403e0 h1:/WRUS7Gg5zYUFu9qnmbSbA9g7DQ4WIplJ/8RgeMTgko= -github.com/hashicorp/memberlist v0.0.0-20190312092157-a8f83c6403e0/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/log15 v0.0.0-20180818164646-67afb5ed74ec h1:CGkYB1Q7DSsH/ku+to+foV4agt2F2miquaLUgF6L178= @@ -88,47 +49,43 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kormat/fmt15 v0.0.0-20181112140556-ee69fecb2656 h1:aG3mi6+atPavBL5PM/s0XqiRuJ2n08aEY9xza16XGTo= github.com/kormat/fmt15 v0.0.0-20181112140556-ee69fecb2656/go.mod h1:8fpYQL5jskFnAq4zE2UpspqEVHuTjurptCxHPpdoBgM= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= -github.com/lib/pq v0.0.0-20190504011754-ceb88a064902 h1:hTuFNgP3IW7lKRAWrORwxPnWgECjvdgtHHi1pL17mDw= -github.com/lib/pq v0.0.0-20190504011754-ceb88a064902/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= +github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lucas-clemente/quic-go v0.0.0-20190427152327-c135b4f1e34c h1:spQRcapCEM0Lki4wfHdLIWP1Zi0MVdQkwCeXlPqwynY= github.com/lucas-clemente/quic-go v0.0.0-20190427152327-c135b4f1e34c/go.mod h1:yhRVJZ3qR+SfGWIqXHsTAQiV1/SBE9x79oj2ekql4qk= github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI= github.com/marten-seemann/qtls v0.2.3 h1:0yWJ43C62LsZt08vuQJDK1uC1czUc3FJeCLPoNAI4vA= github.com/marten-seemann/qtls v0.2.3/go.mod h1:xzjG7avBwGGbdZ8dTGxlBnLArsVKLvwmjgmPuiQEcYk= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5 h1:tHXDdz1cpzGaovsTB+TVB8q90WEokoVmfMqoVcrLUgw= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.0.14 h1:9jZdLNd/P4+SfEJ0TNyxYpsK8N4GtfylBLqtbYN1sbA= -github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= -github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/netsec-ethz/netsec-scion v0.0.0-20190923155313-1f29eb7e2118 h1:jQYDCun3aJqvPMCbd19ifYiqwA/F2X/Cj1KzkQmchpc= github.com/netsec-ethz/netsec-scion v0.0.0-20190923155313-1f29eb7e2118/go.mod h1:TOIJJKLcXBqp8D+MeN6EjnXJPIL/HkBFO5fuOFK0YUI= github.com/netsec-ethz/rains v0.0.0-20190912114116-83f56a7cb2d1 h1:6ax5pk10nNDpYHDFGyp9xK4DNPxPiDqEN5iQfy7IxcY= github.com/netsec-ethz/rains v0.0.0-20190912114116-83f56a7cb2d1/go.mod h1:tfSCaRBJm3ulJOBk+sO29ttLXtYEsdEdRW/p2U5v4Yc= github.com/netsec-ethz/scion-apps v0.0.0-20191003104124-7237654083b2 h1:FN5tPcWbur8JEbpSq10HFQ4LD27ng/VwPHH8IZEMh+M= github.com/netsec-ethz/scion-apps v0.0.0-20191003104124-7237654083b2/go.mod h1:PiUidXuPbmIPFidcETIypgYjmgE7ZLF0lMxFVsAjYnw= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= -github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= @@ -139,72 +96,66 @@ github.com/pierrec/lz4 v1.0.1 h1:w6GMGWSsCI04fTM8wQRdnW74MuJISakuUU0onU0TYB4= github.com/pierrec/lz4 v1.0.1/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/xxHash v0.1.5 h1:n/jBpwTHiER4xYvK3/CdPVnLDPchj8eTJFFLUb4QHBo= github.com/pierrec/xxHash v0.1.5/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v0.9.2 h1:awm861/B8OKDd2I/6o1dy3ra4BamzKhYOiGItCeZ740= github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 h1:PnBWHBf+6L0jOqq0gIVUe6Yk0/QMZ640k6NvkxcBf+8= github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a h1:9a8MnZMP0X2nLJdBg+pBmGgkJlSaKC2KaQmTCk1XDtE= github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I= -github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337 h1:WN9BUFbdyOsSH/XohnWpXOlq9NBD5sGAB2FciQMUEe8= github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/sony/gobreaker v0.0.0-20190329013020-a9b2a3fc7395 h1:7pn9RucHHxJ8kVk+uKJMDbNOz0v0S5PO41bxUhxfBRE= -github.com/sony/gobreaker v0.0.0-20190329013020-a9b2a3fc7395/go.mod h1:XvpJiTD8NibaH7z0NzyfhR1+NQDtR9F/x92xheTwC9k= +github.com/sony/gobreaker v0.4.1 h1:oMnRNZXX5j85zso6xCPRNPtmAycat+WcoKbklScLDgQ= +github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU= github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= -github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= -golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo= -golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c h1:uOCk1iQW6Vc18bnC13MfzScl+wdKBmM9Y9kU7Z83/lw= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82 h1:vsphBvatvfbhlb4PO1BYSr9dzugGxJ/SQHoNufZJq1w= -golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y= +golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/d4l3k/messagediff.v1 v1.2.1 h1:70AthpjunwzUiarMHyED52mj9UwtAnE89l1Gmrt3EU0= gopkg.in/d4l3k/messagediff.v1 v1.2.1/go.mod h1:EUzikiKadqXWcD1AzJLagx0j/BeeWGtn++04Xniyg44= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= @@ -216,7 +167,12 @@ gopkg.in/restruct.v1 v1.0.0-20190323193435-3c2afb705f3c/go.mod h1:WJaLhyHHEQFOgw gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c h1:grhR+C34yXImVGp7EzNk+DTIk+323eIUWOmEevy6bDo= +gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= zombiezen.com/go/capnproto2 v2.17.0+incompatible h1:sIoKPFGNlM38Qh+PBLa9Wzg1j99oInS/Qlk+5N/CHa4= zombiezen.com/go/capnproto2 v2.17.0+incompatible/go.mod h1:XO5Pr2SbXgqZwn0m0Ru54QBqpOf4K5AYBO+8LAOBQEQ= diff --git a/log/disabled.go b/log/disabled.go index 490565ece..a9d4b8d9e 100644 --- a/log/disabled.go +++ b/log/disabled.go @@ -11,7 +11,5 @@ func (*disabledLogger) Level() Level { return OffLevel } -func (*disabledLogger) Log(level Level, pkg string, file string, line int, format string, args ...interface{}) { -} - -func (*disabledLogger) Close() error { return nil } +func (*disabledLogger) Log(_ Level, _ string, _ string, _ int, _ string, _ ...interface{}) {} +func (*disabledLogger) Close() error { return nil } diff --git a/log/log.go b/log/log.go index 1ab3e7843..060e8038c 100644 --- a/log/log.go +++ b/log/log.go @@ -220,7 +220,7 @@ func (l *logger) loop() { if !ok { // close log files for _, w := range l.files { - w.Close() + _ = w.Close() } return } @@ -246,9 +246,9 @@ func (l *logger) loop() { line := l.b.String() - fmt.Fprintf(l.output, line) + _, _ = fmt.Fprintf(l.output, line) for _, w := range l.files { - fmt.Fprintf(w, line) + _, _ = fmt.Fprintf(w, line) } if rec.level == FatalLevel { exitHandler() diff --git a/model/blocklistitem.go b/model/blocklistitem.go index aedf4b8c6..658d38109 100644 --- a/model/blocklistitem.go +++ b/model/blocklistitem.go @@ -16,7 +16,7 @@ type BlockListItem struct { JID string } -// FromBytes deserializes a BlockListItem entity from it's gob binary representation. +// FromBytes deserializes a BlockListItem entity from its binary representation. func (bli *BlockListItem) FromBytes(buf *bytes.Buffer) error { dec := gob.NewDecoder(buf) if err := dec.Decode(&bli.Username); err != nil { @@ -25,8 +25,7 @@ func (bli *BlockListItem) FromBytes(buf *bytes.Buffer) error { return dec.Decode(&bli.JID) } -// ToBytes converts a BlockListItem entity -// to it's gob binary representation. +// ToBytes converts a BlockListItem entity to its binary representation. func (bli *BlockListItem) ToBytes(buf *bytes.Buffer) error { enc := gob.NewEncoder(buf) if err := enc.Encode(&bli.Username); err != nil { diff --git a/model/capabilities/capabilities.go b/model/capabilities/capabilities.go new file mode 100644 index 000000000..bcae9ee27 --- /dev/null +++ b/model/capabilities/capabilities.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package capsmodel + +import ( + "bytes" + "encoding/gob" +) + +// Capabilities represents presence capabilities info +type Capabilities struct { + Node string + Ver string + Features []string +} + +// FromBytes deserializes a Capabilities entity from its binary representation. +func (c *Capabilities) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&c.Node); err != nil { + return err + } + if err := dec.Decode(&c.Ver); err != nil { + return err + } + return dec.Decode(&c.Features) +} + +// ToBytes converts a Capabilities entity to its binary representation. +func (c *Capabilities) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(&c.Node); err != nil { + return err + } + if err := enc.Encode(&c.Ver); err != nil { + return err + } + return enc.Encode(&c.Features) +} + +// HasFeature returns whether or not Capabilities instance contains a concrete feature +func (c *Capabilities) HasFeature(feature string) bool { + for _, f := range c.Features { + if f == feature { + return true + } + } + return false +} diff --git a/model/capabilities/capabilities_test.go b/model/capabilities/capabilities_test.go new file mode 100644 index 000000000..eebb64363 --- /dev/null +++ b/model/capabilities/capabilities_test.go @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package capsmodel + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCapabilities(t *testing.T) { + var c1, c2 Capabilities + c1 = Capabilities{Node: "n", Ver: "v", Features: []string{"ns1", "ns2"}} + + require.True(t, c1.HasFeature("ns2")) + require.False(t, c1.HasFeature("ns3")) + + buf := new(bytes.Buffer) + require.Nil(t, c1.ToBytes(buf)) + require.Nil(t, c2.FromBytes(buf)) + require.Equal(t, c1, c2) +} diff --git a/model/capabilities/presence_caps.go b/model/capabilities/presence_caps.go new file mode 100644 index 000000000..f85a6f740 --- /dev/null +++ b/model/capabilities/presence_caps.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package capsmodel + +import ( + "bytes" + "encoding/gob" + + "github.com/ortuman/jackal/xmpp" +) + +// PresenceCaps represents the combination of along with its capabilities. +type PresenceCaps struct { + Presence *xmpp.Presence + Caps *Capabilities +} + +// FromBytes deserializes a Capabilities entity from its binary representation. +func (p *PresenceCaps) FromBytes(buf *bytes.Buffer) error { + presence, err := xmpp.NewPresenceFromBytes(buf) + if err != nil { + return err + } + var hasCaps bool + + dec := gob.NewDecoder(buf) + if err := dec.Decode(&hasCaps); err != nil { + return err + } + p.Presence = presence + if hasCaps { + return dec.Decode(&p.Caps) + } + return nil +} + +// ToBytes converts a Capabilities entity to its binary representation. +func (p *PresenceCaps) ToBytes(buf *bytes.Buffer) error { + if err := p.Presence.ToBytes(buf); err != nil { + return err + } + enc := gob.NewEncoder(buf) + + hasCaps := p.Caps != nil + if err := enc.Encode(hasCaps); err != nil { + return err + } + if p.Caps != nil { + if err := enc.Encode(p.Caps); err != nil { + return err + } + } + return nil +} diff --git a/model/capabilities/presence_caps_test.go b/model/capabilities/presence_caps_test.go new file mode 100644 index 000000000..c627f8452 --- /dev/null +++ b/model/capabilities/presence_caps_test.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package capsmodel + +import ( + "bytes" + "testing" + + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestPresenceCapabilities(t *testing.T) { + j1, _ := jid.NewWithString("ortuman@jackal.im", true) + + var p1, p2 PresenceCaps + p1 = PresenceCaps{ + Presence: xmpp.NewPresence(j1, j1, xmpp.AvailableType), + } + + buf := new(bytes.Buffer) + require.Nil(t, p1.ToBytes(buf)) + require.Nil(t, p2.FromBytes(buf)) + require.Equal(t, p1, p2) + + var p3, p4 PresenceCaps + p3 = PresenceCaps{ + Presence: xmpp.NewPresence(j1, j1, xmpp.AvailableType), + Caps: &Capabilities{ + Node: "http://jackal.im", + Ver: "v1234", + }, + } + buf = new(bytes.Buffer) + require.Nil(t, p3.ToBytes(buf)) + require.Nil(t, p4.FromBytes(buf)) + require.Equal(t, p3, p4) +} diff --git a/model/muc/config.go b/model/muc/config.go new file mode 100644 index 000000000..490d8dc60 --- /dev/null +++ b/model/muc/config.go @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mucmodel + +import ( + "bytes" + "encoding/gob" + "fmt" +) + +// canSendPM and canGetMemberList values +const ( + All = "all" + Moderators = "moderators" + None = "" +) + +// RoomConfig represents different room types +type RoomConfig struct { + Public bool + Persistent bool + PwdProtected bool + Password string + Open bool + Moderated bool + AllowInvites bool + MaxOccCnt int + AllowSubjChange bool + NonAnonymous bool + canSendPM string + canGetMemberList string +} + +type roomConfigProxy struct { + Public bool `yaml:public` + Persistent bool `yaml:persistent` + PwdProtected bool `yaml:password_protected` + Open bool `yaml:"open"` + Moderated bool `yaml:"moderated"` + AllowInvites bool `yaml:"allow_invites"` + MaxOccCnt int `yaml:"occupant_count"` + NonAnonymous bool `yaml:"non_anonymous"` + CanSendPM string `yaml:"send_pm"` + CanGetMemberList string `yaml:"can_get_member_list"` + AllowSubjChange bool `yaml:"allow_subject_change"` +} + +// FromBytes deserializes a RoomConfig entity from it's gob binary representation. +func (r *RoomConfig) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&r.Public); err != nil { + return err + } + if err := dec.Decode(&r.Persistent); err != nil { + return err + } + if err := dec.Decode(&r.PwdProtected); err != nil { + return err + } + if r.PwdProtected { + if err := dec.Decode(&r.Password); err != nil { + return err + } + } + if err := dec.Decode(&r.Open); err != nil { + return err + } + if err := dec.Decode(&r.Moderated); err != nil { + return err + } + if err := dec.Decode(&r.NonAnonymous); err != nil { + return err + } + if err := dec.Decode(&r.canSendPM); err != nil { + return err + } + if err := dec.Decode(&r.AllowInvites); err != nil { + return err + } + if err := dec.Decode(&r.AllowSubjChange); err != nil { + return err + } + if err := dec.Decode(&r.canGetMemberList); err != nil { + return err + } + if err := dec.Decode(&r.MaxOccCnt); err != nil { + return err + } + return nil +} + +// ToBytes converts a RoomConfig entity to it's gob binary representation. +func (r *RoomConfig) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(&r.Public); err != nil { + return err + } + if err := enc.Encode(&r.Persistent); err != nil { + return err + } + if err := enc.Encode(&r.PwdProtected); err != nil { + return err + } + if r.PwdProtected { + if err := enc.Encode(&r.Password); err != nil { + return err + } + } + if err := enc.Encode(&r.Open); err != nil { + return err + } + if err := enc.Encode(&r.Moderated); err != nil { + return err + } + if err := enc.Encode(&r.NonAnonymous); err != nil { + return err + } + if err := enc.Encode(&r.canSendPM); err != nil { + return err + } + if err := enc.Encode(&r.AllowInvites); err != nil { + return err + } + if err := enc.Encode(&r.AllowSubjChange); err != nil { + return err + } + if err := enc.Encode(&r.canGetMemberList); err != nil { + return err + } + if err := enc.Encode(&r.MaxOccCnt); err != nil { + return err + } + return nil +} + +// NewConfigFromBytes creates and returns a new RoomConfig element from its bytes representation. +func NewConfigFromBytes(buf *bytes.Buffer) (*RoomConfig, error) { + c := &RoomConfig{} + if err := c.FromBytes(buf); err != nil { + return nil, err + } + return c, nil +} + +// UnmarshalYAML satisfies Unmarshaler interface, sets the default room type for the MUC service +func (r *RoomConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := roomConfigProxy{} + if err := unmarshal(&p); err != nil { + return err + } + r.Public = p.Public + r.Persistent = p.Persistent + r.PwdProtected = p.PwdProtected + r.Open = p.Open + r.Moderated = p.Moderated + r.AllowInvites = p.AllowInvites + r.MaxOccCnt = p.MaxOccCnt + r.AllowSubjChange = p.AllowSubjChange + r.NonAnonymous = p.NonAnonymous + err := r.SetWhoCanSendPM(p.CanSendPM) + if err != nil { + return err + } + err = r.SetWhoCanGetMemberList(p.CanGetMemberList) + if err != nil { + return err + } + return nil +} + +func (r *RoomConfig) SetWhoCanSendPM(s string) error { + switch s { + case All, Moderators, None: + r.canSendPM = s + default: + return fmt.Errorf("muc_config: cannot set who can send private messages to %s", s) + } + return nil +} + +func (r *RoomConfig) WhoCanSendPM() string { + return r.canSendPM +} + +func (r *RoomConfig) OccupantCanSendPM(o *Occupant) bool { + var hasPermission bool + switch r.canSendPM { + case All: + hasPermission = true + case None: + hasPermission = false + case Moderators: + hasPermission = o.IsModerator() + } + return hasPermission +} + +func (r *RoomConfig) SetWhoCanGetMemberList(s string) error { + switch s { + case All, Moderators, None: + r.canGetMemberList = s + default: + return fmt.Errorf("muc_config: cannot set who can get member list to %s", s) + } + return nil +} + +func (r *RoomConfig) WhoCanGetMemberList() string { + return r.canGetMemberList +} + +func (r *RoomConfig) OccupantCanGetMemberList(o *Occupant) bool { + var hasPermission bool + switch r.canGetMemberList { + case All: + hasPermission = true + case None: + hasPermission = false + case Moderators: + hasPermission = o.IsModerator() + } + return hasPermission +} + +func (r *RoomConfig) OccupantCanDiscoverRealJID(o *Occupant) bool { + if r.NonAnonymous { + return true + } + return o.IsModerator() +} + +func (r *RoomConfig) OccupantCanChangeSubject(o *Occupant) bool { + if r.AllowSubjChange { + return true + } + return o.IsModerator() +} diff --git a/model/muc/config_test.go b/model/muc/config_test.go new file mode 100644 index 000000000..9336896fd --- /dev/null +++ b/model/muc/config_test.go @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mucmodel + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +const cfgExample = ` +public: true +persistent: true +password_protected: false +moderated: false +allow_invites: false +allow_subject_change: true +enable_logging: true +history_length: 20 +occupant_count: -1 +non_anonymous: true +send_pm: "moderators" +can_get_member_list: "" +` + +func TestRoomConfig_Bytes(t *testing.T) { + rc1 := RoomConfig{ + Public: true, + Persistent: true, + PwdProtected: true, + Password: "pwd", + Open: true, + Moderated: true, + AllowInvites: true, + MaxOccCnt: 20, + AllowSubjChange: false, + NonAnonymous: true, + canSendPM: "", + canGetMemberList: "moderators", + } + + buf := new(bytes.Buffer) + require.Nil(t, rc1.ToBytes(buf)) + + rc2 := RoomConfig{} + require.Nil(t, rc2.FromBytes(buf)) + + assert.EqualValues(t, rc1, rc2) +} + +func TestRoomConfig_UnmarshalYaml(t *testing.T) { + badCfg := `public: "public"` + cfg := &RoomConfig{} + err := yaml.Unmarshal([]byte(badCfg), &cfg) + require.NotNil(t, err) + + goodCfg := cfgExample + cfg = &RoomConfig{} + err = yaml.Unmarshal([]byte(goodCfg), &cfg) + require.Nil(t, err) + require.True(t, cfg.Public) + require.False(t, cfg.PwdProtected) + require.False(t, cfg.Open) + require.True(t, cfg.NonAnonymous) + require.Equal(t, cfg.WhoCanSendPM(), Moderators) +} + +func TestRoomConfig_PrivateFields(t *testing.T) { + cfg := &RoomConfig{} + err := cfg.SetWhoCanSendPM("fail") + require.NotNil(t, err) + err = cfg.SetWhoCanSendPM(Moderators) + require.Nil(t, err) + require.Equal(t, Moderators, cfg.WhoCanSendPM()) + + err = cfg.SetWhoCanGetMemberList("fail") + require.NotNil(t, err) + err = cfg.SetWhoCanGetMemberList(None) + require.Nil(t, err) + require.Equal(t, None, cfg.WhoCanGetMemberList()) +} + +func TestRoomConfig_OccupantPermissions(t *testing.T) { + cfg := &RoomConfig{ + canSendPM: "", + canGetMemberList: "moderators", + NonAnonymous: false, + AllowSubjChange: true, + } + o := &Occupant{ + role: moderator, + } + require.False(t, cfg.OccupantCanSendPM(o)) + require.True(t, cfg.OccupantCanGetMemberList(o)) + require.True(t, cfg.OccupantCanDiscoverRealJID(o)) + require.True(t, cfg.OccupantCanChangeSubject(o)) +} diff --git a/model/muc/occupant.go b/model/muc/occupant.go new file mode 100644 index 000000000..30a8149e3 --- /dev/null +++ b/model/muc/occupant.go @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mucmodel + +import ( + "bytes" + "encoding/gob" + "fmt" + + "github.com/ortuman/jackal/xmpp/jid" +) + +const ( + // Affiliations + member = "member" + admin = "admin" + owner = "owner" + outcast = "outcast" + + // Roles + moderator = "moderator" + participant = "participant" + visitor = "visitor" + none = "none" +) + +// Occupant represents a single user in a MUC room (XEP-0045) +type Occupant struct { + OccupantJID *jid.JID + BareJID *jid.JID + affiliation string + role string + // a set of different resources that the user uses to access this occupant + resources map[string]bool +} + +// FromBytes deserializes an Occupant entity from it's gob binary representation. +func (o *Occupant) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + j, err := jid.NewFromBytes(buf) + if err != nil { + return err + } + o.OccupantJID = j + f, err := jid.NewFromBytes(buf) + if err != nil { + return err + } + o.BareJID = f + if err := dec.Decode(&o.affiliation); err != nil { + return err + } + if err := dec.Decode(&o.role); err != nil { + return err + } + var numResources int + if err := dec.Decode(&numResources); err != nil { + return err + } + o.resources = make(map[string]bool) + for i := 0; i < numResources; i++ { + var res string + if err := dec.Decode(&res); err != nil { + return err + } + o.resources[res] = true + } + return nil +} + +// ToBytes converts an Occupant entity to it's gob binary representation. +func (o *Occupant) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := o.OccupantJID.ToBytes(buf); err != nil { + return err + } + if err := o.BareJID.ToBytes(buf); err != nil { + return err + } + if err := enc.Encode(&o.affiliation); err != nil { + return err + } + if err := enc.Encode(&o.role); err != nil { + return err + } + if err := enc.Encode(len(o.resources)); err != nil { + return err + } + for res, _ := range o.resources { + if err := enc.Encode(&res); err != nil { + return err + } + } + return nil +} + +// NewOccupantFromBytes creates and returns a new Occupant element from its bytes representation. +func NewOccupantFromBytes(buf *bytes.Buffer) (*Occupant, error) { + o := &Occupant{} + if err := o.FromBytes(buf); err != nil { + return nil, err + } + return o, nil +} + +// NewOccupant creates and return a new Occupant element given its occupant and user JIDs +func NewOccupant(occJID, userJID *jid.JID) (*Occupant, error) { + if !occJID.IsFullWithUser() { + return nil, fmt.Errorf("Occupant JID %s is not valid", occJID.String()) + } + if !userJID.IsBare() { + return nil, fmt.Errorf("User JID %s is not a bare JID", userJID.String()) + } + o := &Occupant{ + OccupantJID: occJID, + BareJID: userJID, + } + o.resources = make(map[string]bool) + return o, nil +} + +func (o *Occupant) SetAffiliation(aff string) error { + switch aff { + case owner, admin, member, outcast, None: + o.affiliation = aff + default: + return fmt.Errorf("occupant: this type of affiliation is not supported - %s", aff) + } + return nil +} + +func (o *Occupant) GetAffiliation() string { + return o.affiliation +} + +func (o *Occupant) SetRole(role string) error { + switch role { + case moderator, participant, visitor, None: + o.role = role + default: + return fmt.Errorf("occupant: this type of role is not supported - %s", role) + } + return nil +} + +func (o *Occupant) GetRole() string { + return o.role +} + +func (o *Occupant) HasNoRole() bool { + return o.role == None +} + +func (o *Occupant) IsVisitor() bool { + return o.role == visitor +} + +func (o *Occupant) IsParticipant() bool { + return o.role == participant +} + +func (o *Occupant) IsModerator() bool { + return o.role == moderator +} + +func (o *Occupant) HasNoAffiliation() bool { + return o.affiliation == None +} + +func (o *Occupant) IsOwner() bool { + return o.affiliation == owner +} + +func (o *Occupant) IsAdmin() bool { + return o.affiliation == admin +} + +func (o *Occupant) IsMember() bool { + return o.affiliation == member +} + +func (o *Occupant) IsOutcast() bool { + return o.affiliation == outcast +} + +// GetAllResources returns the list of resources that user accessed this occupant with +func (o *Occupant) GetAllResources() []string { + resources := make([]string, 0, len(o.resources)) + for r := range o.resources { + resources = append(resources, r) + } + return resources +} + +func (o *Occupant) HasResource(s string) bool { + _, found := o.resources[s] + return found +} + +func (o *Occupant) AddResource(s string) { + o.resources[s] = true +} + +func (o *Occupant) DeleteResource(s string) { + delete(o.resources, s) +} + +func (o *Occupant) HasHigherAffiliation(k *Occupant) bool { + switch { + case o.IsOwner(): + return true + case o.IsAdmin(): + return !k.IsOwner() + case o.IsMember(): + return !k.IsOwner() && !k.IsAdmin() + case o.HasNoAffiliation(): + return k.HasNoAffiliation() + } + return false +} + +func (o *Occupant) CanChangeRole(target *Occupant, role string) bool { + switch role { + case none: + return o.IsModerator() && o.HasHigherAffiliation(target) || o.IsOwner() + case visitor: + return o.IsModerator() && target.IsParticipant() || o.IsOwner() + case participant: + return o.IsModerator() && target.IsVisitor() || o.IsAdmin() && !target.IsOwner() || + o.IsOwner() + case moderator: + return o.IsAdmin() || o.IsOwner() + } + return false +} + +func (o *Occupant) CanChangeAffiliation(target *Occupant, affiliation string) bool { + // not allowed to change your own affiliation + if o.OccupantJID.String() == target.OccupantJID.String() { + return false + } + // only admins and owners can change affiliations + if !o.IsAdmin() && !o.IsOwner() { + return false + } + switch affiliation { + case none: + return o.HasHigherAffiliation(target) + case member: + return o.HasHigherAffiliation(target) + case admin: + return o.IsOwner() + case owner: + return o.IsOwner() + } + return false +} diff --git a/model/muc/occupant_test.go b/model/muc/occupant_test.go new file mode 100644 index 000000000..70250d81d --- /dev/null +++ b/model/muc/occupant_test.go @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mucmodel + +import ( + "bytes" + "testing" + + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOccupant_Bytes(t *testing.T) { + jOcc, _ := jid.NewWithString("room@conference.jackal.im/mynick", true) + jFull, _ := jid.NewWithString("ortuman@jackal.im/laptop", true) + o1 := Occupant{ + OccupantJID: jOcc, + BareJID: jFull.ToBareJID(), + affiliation: "owner", + role: "moderator", + resources: make(map[string]bool), + } + o1.resources[jFull.Resource()] = true + + buf := new(bytes.Buffer) + require.Nil(t, o1.ToBytes(buf)) + + o2 := Occupant{} + require.Nil(t, o2.FromBytes(buf)) + + assert.EqualValues(t, o1, o2) +} + +func TestOccupant_RoleAndAffiliation(t *testing.T) { + jo, _ := jid.NewWithString("room@conference.jackal.im/owner", true) + o := &Occupant{ + OccupantJID: jo, + affiliation: "", + role: "visitor", + } + + require.False(t, o.IsOwner()) + require.False(t, o.IsAdmin()) + require.False(t, o.IsMember()) + require.False(t, o.IsOutcast()) + require.True(t, o.HasNoAffiliation()) + + require.True(t, o.IsVisitor()) + require.False(t, o.IsParticipant()) + require.False(t, o.IsModerator()) + + err := o.SetAffiliation("fail") + require.NotNil(t, err) + err = o.SetAffiliation("owner") + require.Nil(t, err) + + err = o.SetRole("fail") + require.NotNil(t, err) + err = o.SetRole(moderator) + require.True(t, o.IsModerator()) + + jo2, _ := jid.NewWithString("room@conference.jackal.im/admin", true) + o2 := &Occupant{ + OccupantJID: jo2, + affiliation: "admin", + role: "moderator", + } + + require.True(t, o.HasHigherAffiliation(o2)) + require.False(t, o2.HasHigherAffiliation(o)) + require.False(t, o.CanChangeRole(o2, "fail")) + require.True(t, o.CanChangeRole(o2, "visitor")) + require.True(t, o.CanChangeAffiliation(o2, "owner")) + require.False(t, o.CanChangeAffiliation(o2, "fail")) + require.False(t, o2.CanChangeAffiliation(o, "admin")) +} + +func TestOccupant_Resources(t *testing.T) { + jOcc, _ := jid.NewWithString("room@conference.jackal.im/mynick", true) + jFull, _ := jid.NewWithString("ortuman@jackal.im/laptop", true) + + o, err := NewOccupant(jOcc, jFull) + require.NotNil(t, err) + require.Nil(t, o) + o, err = NewOccupant(jOcc, jFull.ToBareJID()) + require.Nil(t, err) + require.NotNil(t, o) + + require.False(t, o.HasResource("laptop")) + o.AddResource("laptop") + require.True(t, o.HasResource("laptop")) + require.Len(t, o.GetAllResources(), 1) + o.DeleteResource("laptop") + require.False(t, o.HasResource("laptop")) +} diff --git a/model/muc/room.go b/model/muc/room.go new file mode 100644 index 000000000..fe1d763b9 --- /dev/null +++ b/model/muc/room.go @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mucmodel + +import ( + "bytes" + "encoding/gob" + "fmt" + + "github.com/ortuman/jackal/log" + "github.com/ortuman/jackal/xmpp/jid" +) + +// Room represents a Multi-User Chat Room entity (XEP-0045) +type Room struct { + Config *RoomConfig + RoomJID *jid.JID + Name string + Desc string + Subject string + Language string + Locked bool + //mapping user bare jid to the occupant JID + userToOccupant map[jid.JID]jid.JID + // a set of invited users' bare JIDs who haven't accepted the invitation yet + invitedUsers map[jid.JID]bool + occupantsOnline int +} + +// FromBytes deserializes a Room entity from it's gob binary representation. +func (r *Room) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&r.Name); err != nil { + return err + } + j, err := jid.NewFromBytes(buf) + if err != nil { + return err + } + r.RoomJID = j + if err := dec.Decode(&r.Desc); err != nil { + return err + } + if err := dec.Decode(&r.Subject); err != nil { + return err + } + if err := dec.Decode(&r.Language); err != nil { + return err + } + c, err := NewConfigFromBytes(buf) + if err != nil { + return err + } + r.Config = c + var numberOfOccupants int + if err := dec.Decode(&numberOfOccupants); err != nil { + return err + } + r.userToOccupant = make(map[jid.JID]jid.JID) + for i := 0; i < numberOfOccupants; i++ { + userJID, err := jid.NewFromBytes(buf) + if err != nil { + return err + } + occJID, err := jid.NewFromBytes(buf) + if err != nil { + return err + } + r.userToOccupant[*userJID] = *occJID + } + if err := dec.Decode(&r.Locked); err != nil { + return err + } + var invitedUsersCount int + if err := dec.Decode(&invitedUsersCount); err != nil { + return err + } + r.invitedUsers = make(map[jid.JID]bool) + for i := 0; i < invitedUsersCount; i++ { + userJID, err := jid.NewFromBytes(buf) + if err != nil { + return err + } + r.invitedUsers[*userJID] = true + } + if err := dec.Decode(&r.occupantsOnline); err != nil { + return err + } + return nil +} + +// ToBytes converts a Room entity to it's gob binary representation. +func (r *Room) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(&r.Name); err != nil { + return err + } + if err := r.RoomJID.ToBytes(buf); err != nil { + return err + } + if err := enc.Encode(&r.Desc); err != nil { + return err + } + if err := enc.Encode(&r.Subject); err != nil { + return err + } + if err := enc.Encode(&r.Language); err != nil { + return err + } + if err := r.Config.ToBytes(buf); err != nil { + return err + } + if err := enc.Encode(len(r.userToOccupant)); err != nil { + return err + } + for userJID, occJID := range r.userToOccupant { + if err := userJID.ToBytes(buf); err != nil { + return err + } + if err := occJID.ToBytes(buf); err != nil { + return err + } + } + if err := enc.Encode(&r.Locked); err != nil { + return err + } + if err := enc.Encode(len(r.invitedUsers)); err != nil { + return err + } + for userJID, _ := range r.invitedUsers { + if err := userJID.ToBytes(buf); err != nil { + return err + } + } + if err := enc.Encode(&r.occupantsOnline); err != nil { + return err + } + return nil +} + +func (r *Room) AddOccupant(o *Occupant) { + // if this user was invited, remove from the list of invited users + if r.UserIsInvited(o.BareJID.ToBareJID()) { + o.SetAffiliation("member") + r.DeleteInvite(o.BareJID.ToBareJID()) + } + + err := r.MapUserToOccupantJID(o.BareJID, o.OccupantJID) + if err != nil { + log.Error(err) + return + } + + if o.HasNoRole() { + r.SetDefaultRole(o) + } + + r.occupantsOnline++ +} + +func (r *Room) OccupantLeft(o *Occupant) { + // occupants with no affiliation are deleted once they leave the room + if o.HasNoAffiliation() { + delete(r.userToOccupant, *o.BareJID) + } + r.occupantsOnline-- +} + +func (r *Room) SetDefaultRole(o *Occupant) { + if o.IsOwner() || o.IsAdmin() { + o.SetRole(moderator) + } else if r.Config.Moderated && o.HasNoAffiliation() { + o.SetRole(visitor) + } else { + o.SetRole(participant) + } +} + +// MapUserToOccupantJID adds the mapping between bare user JID and occupant JID +func (r *Room) MapUserToOccupantJID(userJID, occJID *jid.JID) error { + if !occJID.IsFullWithUser() { + return fmt.Errorf("Occupant JID %s is not valid", occJID.String()) + } + if !userJID.IsBare() { + return fmt.Errorf("User JID %s is not a bare JID", userJID.String()) + } + + // if this is the first occupant in the room, create the map + if r.userToOccupant == nil { + r.userToOccupant = make(map[jid.JID]jid.JID) + } + + // only one occupant JID per user is allowed + _, found := r.userToOccupant[*userJID] + if !found { + r.userToOccupant[*userJID] = *occJID + } + + return nil +} + +func (r *Room) UserIsInRoom(userJID *jid.JID) bool { + _, found := r.userToOccupant[*userJID] + return found +} + +// GetOccupantJID returns the occupant JID of the user with provided user JID +func (r *Room) GetOccupantJID(userJID *jid.JID) (jid.JID, bool) { + occJID, found := r.userToOccupant[*userJID] + return occJID, found +} + +// GetAllOccupantJIDs returns slice of occupant JIDs of everyone in the room +func (r *Room) GetAllOccupantJIDs() []jid.JID { + res := make([]jid.JID, 0, len(r.userToOccupant)) + for _, occJID := range r.userToOccupant { + res = append(res, occJID) + } + return res +} + +// GetAllOccupantJIDs returns slice of user JIDs of everyone in the room +func (r *Room) GetAllUserJIDs() []jid.JID { + res := make([]jid.JID, 0, len(r.userToOccupant)) + for usrJID, _ := range r.userToOccupant { + res = append(res, usrJID) + } + return res +} + +// InviteUser adds the user JID into the set of invited users +func (r *Room) InviteUser(userJID *jid.JID) error { + if !userJID.IsBare() { + return fmt.Errorf("User JID %s is not a bare JID", userJID) + } + + if r.invitedUsers == nil { + r.invitedUsers = make(map[jid.JID]bool) + } + + r.invitedUsers[*userJID] = true + return nil +} + +func (r *Room) UserIsInvited(userJID *jid.JID) bool { + // if no one is invited, return false + if r.invitedUsers == nil { + return false + } + + _, invited := r.invitedUsers[*userJID] + return invited +} + +func (r *Room) DeleteInvite(userJID *jid.JID) { + delete(r.invitedUsers, *userJID) +} + +func (r *Room) IsFull() bool { + // MaxOccCnt = -1 used for the rooms with unlimited capacity + if r.Config.MaxOccCnt == -1 { + return false + } + return r.occupantsOnline >= r.Config.MaxOccCnt +} + +func (r *Room) IsEmpty() bool { + return r.occupantsOnline == 0 +} + +func (r *Room) GetOccupantsOnlineCount() int { + return r.occupantsOnline +} + +func (r *Room) SetOccupantsOnlineCount(i int) { + r.occupantsOnline = i +} + +// GetAllInvitedUsers returns slice of user JIDs of everyone invited into the room +func (r *Room) GetAllInvitedUsers() []string { + res := make([]string, 0, len(r.invitedUsers)) + for jid := range r.invitedUsers { + res = append(res, jid.String()) + } + return res +} diff --git a/model/muc/room_test.go b/model/muc/room_test.go new file mode 100644 index 000000000..cea62bad8 --- /dev/null +++ b/model/muc/room_test.go @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mucmodel + +import ( + "bytes" + "testing" + + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRoom_Bytes(t *testing.T) { + r1 := getTestRoom() + r1.userToOccupant = make(map[jid.JID]jid.JID) + r1.invitedUsers = make(map[jid.JID]bool) + buf := new(bytes.Buffer) + require.Nil(t, r1.ToBytes(buf)) + + r2 := &Room{} + require.Nil(t, r2.FromBytes(buf)) + assert.EqualValues(t, r1, r2) +} + +func TestRoom_Occupants(t *testing.T) { + room := getTestRoom() + userJID, _ := jid.NewWithString("ortuman@jackal.im/balcony", true) + occJID, _ := jid.NewWithString("testroom@conference.jackal.im/nick", true) + o := &Occupant{ + OccupantJID: occJID, + BareJID: userJID.ToBareJID(), + affiliation: "member", + } + + room.AddOccupant(o) + require.True(t, o.IsParticipant()) + require.True(t, room.UserIsInRoom(userJID.ToBareJID())) + require.Equal(t, room.occupantsOnline, 1) + resJID, inRoom := room.GetOccupantJID(userJID.ToBareJID()) + require.Equal(t, resJID.String(), occJID.String()) + require.True(t, inRoom) + room.OccupantLeft(o) + require.True(t, room.UserIsInRoom(userJID.ToBareJID())) + require.Equal(t, room.occupantsOnline, 0) +} + +func TestRoom_Invites(t *testing.T) { + room := getTestRoom() + userJID, _ := jid.NewWithString("ortuman@jackal.im/balcony", true) + + require.False(t, room.UserIsInvited(userJID.ToBareJID())) + err := room.InviteUser(userJID.ToBareJID()) + require.Nil(t, err) + require.True(t, room.UserIsInvited(userJID.ToBareJID())) + room.DeleteInvite(userJID.ToBareJID()) + require.False(t, room.UserIsInvited(userJID.ToBareJID())) +} + +func getTestRoom() *Room { + rc := RoomConfig{ + Public: true, + Persistent: true, + PwdProtected: false, + Open: true, + Moderated: true, + } + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + return &Room{ + Name: "testRoom", + RoomJID: j, + Desc: "Room for Testing", + Config: &rc, + Locked: false, + } +} diff --git a/model/pubsub/affiliation.go b/model/pubsub/affiliation.go new file mode 100644 index 000000000..db1b4e7d4 --- /dev/null +++ b/model/pubsub/affiliation.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "encoding/gob" +) + +// affiliation definitions +const ( + Owner = "owner" + Member = "member" + Publisher = "publisher" + Outcast = "outcast" +) + +// subscription definitions +const ( + None = "none" + Subscribed = "subscribed" +) + +// Affiliation represents a pubsub affiliation +type Affiliation struct { + JID string + Affiliation string +} + +// FromBytes deserializes a Affiliation entity from its binary representation. +func (a *Affiliation) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&a.JID); err != nil { + return err + } + return dec.Decode(&a.Affiliation) +} + +// ToBytes converts a Affiliation entity to its binary representation. +func (a *Affiliation) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(a.JID); err != nil { + return err + } + return enc.Encode(a.Affiliation) +} diff --git a/model/pubsub/affiliation_test.go b/model/pubsub/affiliation_test.go new file mode 100644 index 000000000..85a712070 --- /dev/null +++ b/model/pubsub/affiliation_test.go @@ -0,0 +1,23 @@ +package pubsubmodel + +import ( + "bytes" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAffiliation_Serialize(t *testing.T) { + a := Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: "owner", + } + b := bytes.NewBuffer(nil) + require.Nil(t, a.ToBytes(b)) + + var a2 Affiliation + require.Nil(t, a2.FromBytes(b)) + + require.True(t, reflect.DeepEqual(a, a2)) +} diff --git a/model/pubsub/item.go b/model/pubsub/item.go new file mode 100644 index 000000000..5201480f5 --- /dev/null +++ b/model/pubsub/item.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "encoding/gob" + + "github.com/ortuman/jackal/xmpp" +) + +// Item represents a pubsub node item +type Item struct { + ID string + Publisher string + Payload xmpp.XElement +} + +// FromBytes deserializes a Item entity from its binary representation. +func (i *Item) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&i.ID); err != nil { + return err + } + if err := dec.Decode(&i.Publisher); err != nil { + return err + } + var hasPayload bool + if err := dec.Decode(&hasPayload); err != nil { + return err + } + if hasPayload { + var elem xmpp.Element + if err := elem.FromBytes(buf); err != nil { + return err + } + i.Payload = &elem + } + return nil +} + +// ToBytes converts a Item entity to its binary representation. +func (i *Item) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(i.ID); err != nil { + return err + } + if err := enc.Encode(i.Publisher); err != nil { + return err + } + hasPayload := i.Payload != nil + if err := enc.Encode(hasPayload); err != nil { + return err + } + if i.Payload != nil { + return i.Payload.ToBytes(buf) + } + return nil +} diff --git a/model/pubsub/item_test.go b/model/pubsub/item_test.go new file mode 100644 index 000000000..87aa328eb --- /dev/null +++ b/model/pubsub/item_test.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "reflect" + "testing" + + "github.com/ortuman/jackal/xmpp" + "github.com/stretchr/testify/require" +) + +func TestItem_Serialization(t *testing.T) { + it := Item{} + it.ID = "1234" + it.Publisher = "ortuman@jackal.im" + it.Payload = xmpp.NewElementName("el") + + buf := bytes.NewBuffer(nil) + require.Nil(t, it.ToBytes(buf)) + + it2 := Item{} + _ = it2.FromBytes(buf) + + require.True(t, reflect.DeepEqual(&it, &it2)) + + // nil payload + it.Payload = nil + + buf2 := bytes.NewBuffer(nil) + require.Nil(t, it.ToBytes(buf2)) + + it3 := Item{} + _ = it3.FromBytes(buf2) + + require.True(t, reflect.DeepEqual(&it, &it3)) +} diff --git a/model/pubsub/node.go b/model/pubsub/node.go new file mode 100644 index 000000000..704f952ae --- /dev/null +++ b/model/pubsub/node.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "encoding/gob" +) + +// Node represents a pubsub node +type Node struct { + Host string + Name string + Options Options +} + +// FromBytes deserializes a Node entity from its binary representation. +func (n *Node) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&n.Host); err != nil { + return err + } + if err := dec.Decode(&n.Name); err != nil { + return err + } + return dec.Decode(&n.Options) +} + +// ToBytes converts a Node entity to its binary representation. +func (n *Node) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(n.Host); err != nil { + return err + } + if err := enc.Encode(n.Name); err != nil { + return err + } + return enc.Encode(n.Options) +} diff --git a/model/pubsub/node_test.go b/model/pubsub/node_test.go new file mode 100644 index 000000000..bcec2caeb --- /dev/null +++ b/model/pubsub/node_test.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNode_Serialization(t *testing.T) { + n := Node{} + n.Name = "playing_lists" + n.Host = "jackal.im" + + n.Options.Title = "Playing lists" + n.Options.NotifySub = true + + buf := bytes.NewBuffer(nil) + require.Nil(t, n.ToBytes(buf)) + + n2 := Node{} + _ = n2.FromBytes(buf) + + require.True(t, reflect.DeepEqual(&n, &n2)) +} diff --git a/model/pubsub/opt.go b/model/pubsub/opt.go new file mode 100644 index 000000000..b861f7d45 --- /dev/null +++ b/model/pubsub/opt.go @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/ortuman/jackal/module/xep0004" +) + +const nodeConfigNamespace = "http://jabber.org/protocol/pubsub#node_config" + +const ( + titleFieldVar = "pubsub#title" + deliverNotificationsFieldVar = "pubsub#deliver_notifications" + deliverPayloadsFieldVar = "pubsub#deliver_payloads" + persistItemsFieldVar = "pubsub#persist_items" + maxItemsFieldVar = "pubsub#max_items" + accessModelFieldVar = "pubsub#access_model" + sendLastPublishedItemFieldVar = "pubsub#send_last_published_item" + rosterGroupsAllowedFieldVar = "pubsub#roster_groups_allowed" + notificationTypeFieldVar = "pubsub#notification_type" + notifyConfigFieldVar = "pubsub#notify_config" + notifyDeleteFieldVar = "pubsub#notify_delete" + notifyRetractFieldVar = "pubsub#notify_retract" + notifySubFieldVar = "pubsub#notify_sub" +) + +const ( + // Open represents 'open' access model. + Open = "open" + + // Presence represents 'presence' access model. + Presence = "presence" + + // Roster represents 'roster' access model. + Roster = "roster" + + // WhiteList represents 'whitelist' access model. + WhiteList = "whitelist" + + // Never represents 'never' send last published item option. + Never = "never" + + // OnSub represents 'on_sub' send last published item option. + OnSub = "on_sub" + + // OnSubAndPresence represents 'on_sub_and_presence' send last published item option. + OnSubAndPresence = "on_sub_and_presence" +) + +// Options represents pubsub node configuration options +type Options struct { + Title string + DeliverNotifications bool + DeliverPayloads bool + PersistItems bool + MaxItems int64 + AccessModel string + SendLastPublishedItem string + RosterGroupsAllowed []string + NotificationType string + NotifyConfig bool + NotifyDelete bool + NotifySub bool +} + +// NewOptionsFromMap returns a new node Options instance derived from an input map. +func NewOptionsFromMap(m map[string]string) (*Options, error) { + opt := &Options{} + + // extract options values + opt.Title = m[titleFieldVar] + opt.DeliverNotifications, _ = strconv.ParseBool(m[deliverNotificationsFieldVar]) + opt.DeliverPayloads, _ = strconv.ParseBool(m[deliverPayloadsFieldVar]) + opt.PersistItems, _ = strconv.ParseBool(m[persistItemsFieldVar]) + opt.MaxItems, _ = strconv.ParseInt(m[maxItemsFieldVar], 10, 32) + opt.NotificationType = m[notificationTypeFieldVar] + opt.NotifyConfig, _ = strconv.ParseBool(m[notifyConfigFieldVar]) + opt.NotifyDelete, _ = strconv.ParseBool(m[notifyDeleteFieldVar]) + opt.NotifySub, _ = strconv.ParseBool(m[notifySubFieldVar]) + + // extract roster allowed groups + allowedRosterGroupsJSON := m[rosterGroupsAllowedFieldVar] + if len(allowedRosterGroupsJSON) > 0 { + if err := json.NewDecoder(strings.NewReader(allowedRosterGroupsJSON)).Decode(&opt.RosterGroupsAllowed); err != nil { + return nil, err + } + } + + // extract options values + accessModel := m[accessModelFieldVar] + switch accessModel { + case Open, Presence, Roster, WhiteList: + opt.AccessModel = accessModel + default: + return nil, fmt.Errorf("invalid access_model value: %s", accessModel) + } + + sendLastPublishedItem := m[sendLastPublishedItemFieldVar] + switch sendLastPublishedItem { + case Never, OnSub, OnSubAndPresence: + opt.SendLastPublishedItem = sendLastPublishedItem + default: + return nil, fmt.Errorf("invalid send_last_published_item value: %s", sendLastPublishedItem) + } + return opt, nil +} + +// NewOptionsFromSubmitForm returns a new node Options instance derived from a submit form. +func NewOptionsFromSubmitForm(form *xep0004.DataForm) (*Options, error) { + opt := &Options{} + fields := form.Fields + if len(fields) == 0 { + return nil, errors.New("form empty fields") + } + // validate form type + formType := fields.ValueForFieldOfType(xep0004.FormType, xep0004.Hidden) + if form.Type != xep0004.Submit || formType != nodeConfigNamespace { + return nil, errors.New("invalid form type") + } + // extract options values + accessModel := fields.ValueForField(accessModelFieldVar) + switch accessModel { + case Open, Presence, Roster, WhiteList: + opt.AccessModel = accessModel + default: + return nil, fmt.Errorf("invalid access_model value: %s", accessModel) + } + + sendLastPublishedItem := fields.ValueForField(sendLastPublishedItemFieldVar) + switch sendLastPublishedItem { + case Never, OnSub, OnSubAndPresence: + opt.SendLastPublishedItem = sendLastPublishedItem + default: + return nil, fmt.Errorf("invalid send_last_published_item value: %s", sendLastPublishedItem) + } + + opt.Title = fields.ValueForField(titleFieldVar) + opt.DeliverNotifications, _ = strconv.ParseBool(fields.ValueForField(deliverNotificationsFieldVar)) + opt.DeliverPayloads, _ = strconv.ParseBool(fields.ValueForField(deliverPayloadsFieldVar)) + opt.PersistItems, _ = strconv.ParseBool(fields.ValueForField(persistItemsFieldVar)) + opt.RosterGroupsAllowed = fields.ValuesForField(rosterGroupsAllowedFieldVar) + opt.MaxItems, _ = strconv.ParseInt(fields.ValueForField(maxItemsFieldVar), 10, 32) + opt.NotificationType = fields.ValueForField(notificationTypeFieldVar) + opt.NotifyConfig, _ = strconv.ParseBool(fields.ValueForField(notifyConfigFieldVar)) + opt.NotifyDelete, _ = strconv.ParseBool(fields.ValueForField(notifyDeleteFieldVar)) + opt.NotifySub, _ = strconv.ParseBool(fields.ValueForField(notifySubFieldVar)) + + return opt, nil +} + +// Map returns Options map representation. +func (opt *Options) Map() (map[string]string, error) { + // marshall roster allowed groups + b, err := json.Marshal(&opt.RosterGroupsAllowed) + if err != nil { + return nil, err + } + m := make(map[string]string) + m[titleFieldVar] = opt.Title + m[deliverNotificationsFieldVar] = strconv.FormatBool(opt.DeliverNotifications) + m[deliverPayloadsFieldVar] = strconv.FormatBool(opt.DeliverPayloads) + m[persistItemsFieldVar] = strconv.FormatBool(opt.PersistItems) + m[maxItemsFieldVar] = strconv.Itoa(int(opt.MaxItems)) + m[accessModelFieldVar] = opt.AccessModel + m[rosterGroupsAllowedFieldVar] = string(b) + m[sendLastPublishedItemFieldVar] = opt.SendLastPublishedItem + m[notificationTypeFieldVar] = opt.NotificationType + m[notifyConfigFieldVar] = strconv.FormatBool(opt.NotifyConfig) + m[notifyDeleteFieldVar] = strconv.FormatBool(opt.NotifyDelete) + m[notifySubFieldVar] = strconv.FormatBool(opt.NotifySub) + return m, nil +} + +// Form returns Options form representation. +func (opt *Options) Form(rosterGroups []string) *xep0004.DataForm { + form := xep0004.DataForm{ + Type: xep0004.Form, + } + // include form type + form.Fields = append(form.Fields, xep0004.Field{ + Var: xep0004.FormType, + Type: xep0004.Hidden, + Values: []string{nodeConfigNamespace}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: titleFieldVar, + Type: xep0004.TextSingle, + Label: "Node title", + Values: []string{opt.Title}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: deliverNotificationsFieldVar, + Type: xep0004.Boolean, + Label: "Whether to deliver event notifications", + Values: []string{strconv.FormatBool(opt.DeliverNotifications)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: deliverPayloadsFieldVar, + Type: xep0004.Boolean, + Label: "Whether to deliver payloads with event notifications", + Values: []string{strconv.FormatBool(opt.DeliverPayloads)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: persistItemsFieldVar, + Type: xep0004.Boolean, + Label: "Whether to persist items to storage", + Values: []string{strconv.FormatBool(opt.PersistItems)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: maxItemsFieldVar, + Type: xep0004.Boolean, + Label: "Max number of items to persist", + Values: []string{strconv.FormatInt(opt.MaxItems, 10)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: accessModelFieldVar, + Type: xep0004.ListSingle, + Values: []string{opt.AccessModel}, + Label: "Specify the subscriber model", + Options: []xep0004.Option{ + {Label: "Open", Value: Open}, + {Label: "Presence Sharing", Value: Presence}, + {Label: "Roster Groups", Value: Roster}, + {Label: "Whitelist", Value: WhiteList}, + }, + }) + // roster groups allowed + var rosterGroupOpts []xep0004.Option + for _, rg := range rosterGroups { + rosterGroupOpts = append(rosterGroupOpts, xep0004.Option{Label: rg, Value: rg}) + } + form.Fields = append(form.Fields, xep0004.Field{ + Var: rosterGroupsAllowedFieldVar, + Type: xep0004.ListMulti, + Values: opt.RosterGroupsAllowed, + Label: "Roster groups allowed to subscribe", + Options: rosterGroupOpts, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: sendLastPublishedItemFieldVar, + Type: xep0004.ListSingle, + Label: "When to send the last published item", + Values: []string{opt.SendLastPublishedItem}, + Options: []xep0004.Option{ + {Label: "Never", Value: Never}, + {Label: "When a new subscription is processed", Value: OnSub}, + {Label: "When a new subscription is processed and whenever a subscriber comes online", Value: OnSubAndPresence}, + }, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notificationTypeFieldVar, + Type: xep0004.ListSingle, + Label: "Specify the delivery style for event notifications", + Values: []string{opt.NotificationType}, + Options: []xep0004.Option{ + {Value: "normal"}, + {Value: "headline"}, + }, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notifyConfigFieldVar, + Type: xep0004.Boolean, + Label: "Notify subscribers when the node configuration changes", + Values: []string{strconv.FormatBool(opt.NotifyConfig)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notifyDeleteFieldVar, + Type: xep0004.Boolean, + Label: "Notify subscribers when the node is deleted", + Values: []string{strconv.FormatBool(opt.NotifyDelete)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notifySubFieldVar, + Type: xep0004.Boolean, + Label: "Notify owners about new subscribers and unsubscribes", + Values: []string{strconv.FormatBool(opt.NotifySub)}, + }) + return &form +} + +// ResultForm returns Options result form representation. +func (opt *Options) ResultForm() *xep0004.DataForm { + form := xep0004.DataForm{ + Type: xep0004.Result, + } + // include form type + form.Fields = append(form.Fields, xep0004.Field{ + Var: xep0004.FormType, + Type: xep0004.Hidden, + Values: []string{nodeConfigNamespace}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: titleFieldVar, + Values: []string{opt.Title}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: deliverNotificationsFieldVar, + Values: []string{strconv.FormatBool(opt.DeliverNotifications)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: deliverPayloadsFieldVar, + Values: []string{strconv.FormatBool(opt.DeliverPayloads)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: persistItemsFieldVar, + Values: []string{strconv.FormatBool(opt.PersistItems)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: maxItemsFieldVar, + Values: []string{strconv.Itoa(int(opt.MaxItems))}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: accessModelFieldVar, + Values: []string{opt.AccessModel}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: accessModelFieldVar, + Values: []string{opt.AccessModel}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: sendLastPublishedItemFieldVar, + Values: []string{opt.SendLastPublishedItem}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notificationTypeFieldVar, + Values: []string{opt.NotificationType}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notifyConfigFieldVar, + Values: []string{strconv.FormatBool(opt.NotifyConfig)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notifyDeleteFieldVar, + Values: []string{strconv.FormatBool(opt.NotifyDelete)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: notifySubFieldVar, + Values: []string{strconv.FormatBool(opt.NotifySub)}, + }) + return &form +} diff --git a/model/pubsub/opt_test.go b/model/pubsub/opt_test.go new file mode 100644 index 000000000..556b3bbbf --- /dev/null +++ b/model/pubsub/opt_test.go @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "reflect" + "testing" + + "github.com/ortuman/jackal/module/xep0004" + "github.com/stretchr/testify/require" +) + +func TestOptions_New(t *testing.T) { + opt, err := NewOptionsFromSubmitForm(&xep0004.DataForm{}) + require.Nil(t, opt) + require.NotNil(t, err) + + form := &xep0004.DataForm{ + Type: xep0004.Submit, + Fields: xep0004.Fields{ + { + Var: "FORM_TYPE", + Type: xep0004.Hidden, + Values: []string{nodeConfigNamespace}, + }, + { + Var: titleFieldVar, + Values: []string{"Princely Musings (Atom)"}, + }, + { + Var: deliverNotificationsFieldVar, + Values: []string{"1"}, + }, + { + Var: deliverPayloadsFieldVar, + Values: []string{"1"}, + }, + { + Var: persistItemsFieldVar, + Values: []string{"1"}, + }, + { + Var: maxItemsFieldVar, + Values: []string{"10"}, + }, + { + Var: accessModelFieldVar, + Values: []string{"open"}, + }, + { + Var: sendLastPublishedItemFieldVar, + Values: []string{"never"}, + }, + { + Var: notificationTypeFieldVar, + Values: []string{"headline"}, + }, + { + Var: notifyConfigFieldVar, + Values: []string{"1"}, + }, + { + Var: notifyDeleteFieldVar, + Values: []string{"TRUE"}, + }, + { + Var: notifyRetractFieldVar, + Values: []string{"TRUE"}, + }, + { + Var: notifySubFieldVar, + Values: []string{"TRUE"}, + }, + }, + } + opt, err = NewOptionsFromSubmitForm(form) + require.NotNil(t, opt) + require.Nil(t, err) + + require.Equal(t, "Princely Musings (Atom)", opt.Title) + require.True(t, opt.DeliverNotifications) + require.True(t, opt.DeliverPayloads) + require.True(t, opt.PersistItems) + require.Equal(t, int64(10), opt.MaxItems) + require.Equal(t, Open, opt.AccessModel) + require.Equal(t, Never, opt.SendLastPublishedItem) + require.Equal(t, "headline", opt.NotificationType) + require.True(t, opt.NotifyConfig) + require.True(t, opt.NotifyDelete) + require.True(t, opt.NotifySub) + + form2 := opt.ResultForm() + form2.Type = xep0004.Submit + + opt2, err := NewOptionsFromSubmitForm(form2) + require.NotNil(t, opt2) + require.Nil(t, err) + + require.True(t, reflect.DeepEqual(&opt, &opt2)) + + optMap, _ := opt2.Map() + opt3, err := NewOptionsFromMap(optMap) + require.Nil(t, err) + require.True(t, reflect.DeepEqual(&opt, &opt3)) +} diff --git a/model/pubsub/subscription.go b/model/pubsub/subscription.go new file mode 100644 index 000000000..0e6cfc545 --- /dev/null +++ b/model/pubsub/subscription.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "encoding/gob" +) + +// Subscription represents a pubsub node subscription +type Subscription struct { + SubID string + JID string + Subscription string +} + +// FromBytes deserializes a Subscription entity from its binary representation. +func (s *Subscription) FromBytes(buf *bytes.Buffer) error { + dec := gob.NewDecoder(buf) + if err := dec.Decode(&s.SubID); err != nil { + return err + } + if err := dec.Decode(&s.JID); err != nil { + return err + } + return dec.Decode(&s.Subscription) +} + +// ToBytes converts a Subscription entity to its binary representation. +func (s *Subscription) ToBytes(buf *bytes.Buffer) error { + enc := gob.NewEncoder(buf) + if err := enc.Encode(s.SubID); err != nil { + return err + } + if err := enc.Encode(s.JID); err != nil { + return err + } + return enc.Encode(s.Subscription) +} diff --git a/model/pubsub/subscription_test.go b/model/pubsub/subscription_test.go new file mode 100644 index 000000000..accadc0d2 --- /dev/null +++ b/model/pubsub/subscription_test.go @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pubsubmodel + +import ( + "bytes" + "reflect" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestSubscription_Serialize(t *testing.T) { + s := Subscription{ + SubID: uuid.New().String(), + JID: "ortuman@jackal.im", + Subscription: "subscribed", + } + b := bytes.NewBuffer(nil) + require.Nil(t, s.ToBytes(b)) + + var s2 Subscription + require.Nil(t, s2.FromBytes(b)) + + require.True(t, reflect.DeepEqual(s, s2)) +} diff --git a/model/rostermodel/item.go b/model/roster/item.go similarity index 100% rename from model/rostermodel/item.go rename to model/roster/item.go diff --git a/model/rostermodel/item_test.go b/model/roster/item_test.go similarity index 100% rename from model/rostermodel/item_test.go rename to model/roster/item_test.go diff --git a/model/rostermodel/notification.go b/model/roster/notification.go similarity index 100% rename from model/rostermodel/notification.go rename to model/roster/notification.go diff --git a/model/rostermodel/notification_test.go b/model/roster/notification_test.go similarity index 100% rename from model/rostermodel/notification_test.go rename to model/roster/notification_test.go diff --git a/model/rostermodel/version.go b/model/roster/version.go similarity index 100% rename from model/rostermodel/version.go rename to model/roster/version.go diff --git a/model/rostermodel/version_test.go b/model/roster/version_test.go similarity index 100% rename from model/rostermodel/version_test.go rename to model/roster/version_test.go diff --git a/model/serializer/serializer.go b/model/serializer/serializer.go index cffc16433..c7c85ef05 100644 --- a/model/serializer/serializer.go +++ b/model/serializer/serializer.go @@ -11,7 +11,7 @@ import ( "fmt" "reflect" - "github.com/ortuman/jackal/pool" + "github.com/ortuman/jackal/util/pool" ) var bufPool = pool.NewBufferPool() @@ -26,7 +26,7 @@ type Deserializer interface { FromBytes(buf *bytes.Buffer) error } -// Serialize converts an slice of Serializer elements into its bytes representation. +// SerializeSlice converts an slice of Serializer elements into its bytes representation. func SerializeSlice(slice interface{}) ([]byte, error) { t := reflect.TypeOf(slice).Elem() if t.Kind() != reflect.Slice { @@ -58,7 +58,7 @@ func SerializeSlice(slice interface{}) ([]byte, error) { return res, nil } -// Deserialize reads an entity slice of Deserilizer elements from its bytes representation. +// DeserializeSlice reads an entity slice of Deserilizer elements from its bytes representation. func DeserializeSlice(b []byte, slice interface{}) error { t := reflect.TypeOf(slice).Elem() if t.Kind() != reflect.Slice { diff --git a/module/config.go b/module/config.go index 4e347cac8..58a43a8b6 100644 --- a/module/config.go +++ b/module/config.go @@ -10,6 +10,7 @@ import ( "github.com/ortuman/jackal/module/offline" "github.com/ortuman/jackal/module/roster" + "github.com/ortuman/jackal/module/xep0045" "github.com/ortuman/jackal/module/xep0077" "github.com/ortuman/jackal/module/xep0092" "github.com/ortuman/jackal/module/xep0199" @@ -23,6 +24,7 @@ type Config struct { Registration xep0077.Config Version xep0092.Config Ping xep0199.Config + Muc xep0045.Config } type configProxy struct { @@ -32,6 +34,7 @@ type configProxy struct { Registration xep0077.Config `yaml:"mod_registration"` Version xep0092.Config `yaml:"mod_version"` Ping xep0199.Config `yaml:"mod_ping"` + Muc xep0045.Config `yaml:"mod_muc"` } // UnmarshalYAML satisfies Unmarshaler interface. @@ -44,8 +47,8 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { enabled := make(map[string]struct{}, len(p.Enabled)) for _, mod := range p.Enabled { switch mod { - case "roster", "last_activity", "private", "vcard", "registration", "version", "blocking_command", - "ping", "offline": + case "roster", "last_activity", "private", "vcard", "registration", "pep", "version", "blocking_command", + "ping", "offline", "muc": break default: return fmt.Errorf("module.Config: unrecognized module: %s", mod) @@ -58,5 +61,6 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { cfg.Registration = p.Registration cfg.Version = p.Version cfg.Ping = p.Ping + cfg.Muc = p.Muc return nil } diff --git a/module/module.go b/module/module.go index 6abfb22d4..590c36834 100644 --- a/module/module.go +++ b/module/module.go @@ -13,13 +13,17 @@ import ( "github.com/ortuman/jackal/module/roster" "github.com/ortuman/jackal/module/xep0012" "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/module/xep0045" "github.com/ortuman/jackal/module/xep0049" "github.com/ortuman/jackal/module/xep0054" "github.com/ortuman/jackal/module/xep0077" "github.com/ortuman/jackal/module/xep0092" + "github.com/ortuman/jackal/module/xep0115" + "github.com/ortuman/jackal/module/xep0163" "github.com/ortuman/jackal/module/xep0191" "github.com/ortuman/jackal/module/xep0199" "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/xmpp" ) @@ -33,13 +37,11 @@ type Module interface { type IQHandler interface { Module - // MatchesIQ returns whether or not an IQ should be - // processed by the module. + // MatchesIQ returns whether or not an IQ should be processed by the module. MatchesIQ(iq *xmpp.IQ) bool - // ProcessIQ processes a module IQ taking according actions - // over the associated stream. - ProcessIQ(iq *xmpp.IQ) + // ProcessIQ processes a module IQ taking according actions over the associated stream. + ProcessIQ(ctx context.Context, iq *xmpp.IQ) } // Modules structure keeps reference to a set of preconfigured modules. @@ -52,54 +54,51 @@ type Modules struct { VCard *xep0054.VCard Register *xep0077.Register Version *xep0092.Version + Pep *xep0163.Pep BlockingCmd *xep0191.BlockingCommand Ping *xep0199.Ping + Muc *xep0045.Muc - router *router.Router + router router.Router iqHandlers []IQHandler all []Module } // New returns a set of modules derived from a concrete configuration. -func New(config *Config, router *router.Router) *Modules { +func New(config *Config, router router.Router, reps repository.Container, allocationID string) *Modules { + var presenceHub = xep0115.New(router, reps.Presences(), allocationID) + m := &Modules{router: router} // XEP-0030: Service Discovery (https://xmpp.org/extensions/xep-0030.html) - m.DiscoInfo = xep0030.New(router) + m.DiscoInfo = xep0030.New(router, reps.Roster()) m.iqHandlers = append(m.iqHandlers, m.DiscoInfo) m.all = append(m.all, m.DiscoInfo) - // Roster (https://xmpp.org/rfcs/rfc3921.html#roster) - if _, ok := config.Enabled["roster"]; ok { - m.Roster = roster.New(&config.Roster, router) - m.iqHandlers = append(m.iqHandlers, m.Roster) - m.all = append(m.all, m.Roster) - } - // XEP-0012: Last Activity (https://xmpp.org/extensions/xep-0012.html) if _, ok := config.Enabled["last_activity"]; ok { - m.LastActivity = xep0012.New(m.DiscoInfo, router) + m.LastActivity = xep0012.New(m.DiscoInfo, router, reps.User(), reps.Roster()) m.iqHandlers = append(m.iqHandlers, m.LastActivity) m.all = append(m.all, m.LastActivity) } // XEP-0049: Private XML Storage (https://xmpp.org/extensions/xep-0049.html) if _, ok := config.Enabled["private"]; ok { - m.Private = xep0049.New(router) + m.Private = xep0049.New(router, reps.Private()) m.iqHandlers = append(m.iqHandlers, m.Private) m.all = append(m.all, m.Private) } // XEP-0054: vcard-temp (https://xmpp.org/extensions/xep-0054.html) if _, ok := config.Enabled["vcard"]; ok { - m.VCard = xep0054.New(m.DiscoInfo, router) + m.VCard = xep0054.New(m.DiscoInfo, router, reps.VCard()) m.iqHandlers = append(m.iqHandlers, m.VCard) m.all = append(m.all, m.VCard) } // XEP-0077: In-band registration (https://xmpp.org/extensions/xep-0077.html) if _, ok := config.Enabled["registration"]; ok { - m.Register = xep0077.New(&config.Registration, m.DiscoInfo, router) + m.Register = xep0077.New(&config.Registration, m.DiscoInfo, router, reps.User()) m.iqHandlers = append(m.iqHandlers, m.Register) m.all = append(m.all, m.Register) } @@ -113,13 +112,20 @@ func New(config *Config, router *router.Router) *Modules { // XEP-0160: Offline message storage (https://xmpp.org/extensions/xep-0160.html) if _, ok := config.Enabled["offline"]; ok { - m.Offline = offline.New(&config.Offline, m.DiscoInfo, router) + m.Offline = offline.New(&config.Offline, m.DiscoInfo, router, reps.Offline()) m.all = append(m.all, m.Offline) } + // XEP-0163: Personal Eventing Protocol (https://xmpp.org/extensions/xep-0163.html) + if _, ok := config.Enabled["pep"]; ok { + m.Pep = xep0163.New(m.DiscoInfo, presenceHub, router, reps.Roster(), reps.PubSub()) + m.iqHandlers = append(m.iqHandlers, m.Pep) + m.all = append(m.all, m.Pep) + } + // XEP-0191: Blocking Command (https://xmpp.org/extensions/xep-0191.html) if _, ok := config.Enabled["blocking_command"]; ok { - m.BlockingCmd = xep0191.New(m.DiscoInfo, m.Roster, router) + m.BlockingCmd = xep0191.New(m.DiscoInfo, presenceHub, router, reps.Roster(), reps.BlockList()) m.iqHandlers = append(m.iqHandlers, m.BlockingCmd) m.all = append(m.all, m.BlockingCmd) } @@ -130,23 +136,39 @@ func New(config *Config, router *router.Router) *Modules { m.iqHandlers = append(m.iqHandlers, m.Ping) m.all = append(m.all, m.Ping) } + + // Roster (https://xmpp.org/rfcs/rfc3921.html#roster) + if _, ok := config.Enabled["roster"]; ok { + m.iqHandlers = append(m.iqHandlers, presenceHub) + + m.Roster = roster.New(&config.Roster, presenceHub, m.Pep, router, reps.User(), reps.Roster()) + m.iqHandlers = append(m.iqHandlers, m.Roster) + m.all = append(m.all, m.Roster) + } + + // XEP-0045: Multi-User Chat (https://xmpp.org/extensions/xep-0045.html) + if _, ok := config.Enabled["muc"]; ok { + m.Muc = xep0045.New(&config.Muc, m.DiscoInfo, router, reps.Room(), reps.Occupant()) + m.all = append(m.all, m.Muc) + m.iqHandlers = append(m.iqHandlers, m.Muc) + } + return m } -// ProcessIQ process a module IQ returning 'service unavailable' -// in case it can't be properly handled. -func (m *Modules) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ process a module IQ returning 'service unavailable' in case it couldn't be properly handled. +func (m *Modules) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { for _, handler := range m.iqHandlers { if !handler.MatchesIQ(iq) { continue } - handler.ProcessIQ(iq) + handler.ProcessIQ(ctx, iq) return } // ...IQ not handled... if iq.IsGet() || iq.IsSet() { - _ = m.router.Route(iq.ServiceUnavailableError()) + _ = m.router.Route(ctx, iq.ServiceUnavailableError()) } } diff --git a/module/module_test.go b/module/module_test.go index 9de6ccf9b..58728dfc3 100644 --- a/module/module_test.go +++ b/module/module_test.go @@ -12,8 +12,12 @@ import ( "testing" "time" + "github.com/ortuman/jackal/router/host" + "github.com/google/uuid" + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -21,6 +25,8 @@ import ( yaml "gopkg.in/yaml.v2" ) +const totalModuleCount = 11 + type fakeModule struct { shutdownCh chan bool } @@ -34,26 +40,28 @@ func (m *fakeModule) Shutdown() error { func TestModules_New(t *testing.T) { mods := setupModules(t) - defer mods.Shutdown(context.Background()) + defer func() { _ = mods.Shutdown(context.Background()) }() - require.Equal(t, 10, len(mods.all)) + require.Equal(t, totalModuleCount, len(mods.all)) } func TestModules_ProcessIQ(t *testing.T) { mods := setupModules(t) - defer mods.Shutdown(context.Background()) + defer func() { _ = mods.Shutdown(context.Background()) }() j0, _ := jid.NewWithString("ortuman@jackal.im/balcony", true) j1, _ := jid.NewWithString("ortuman@jackal.im/yard", true) stm := stream.NewMockC2S(uuid.New().String(), j0) - mods.router.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j0.ToBareJID(), j0, xmpp.AvailableType)) + + mods.router.Bind(context.Background(), stm) iqID := uuid.New().String() iq := xmpp.NewIQType(iqID, xmpp.GetType) iq.SetFromJID(j0) iq.SetToJID(j1) - mods.ProcessIQ(iq) + mods.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.NotNil(t, elem) @@ -86,8 +94,13 @@ func setupModules(t *testing.T) *Modules { err = yaml.Unmarshal(b, &config) require.Nil(t, err) - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: "jackal.im", Certificate: tls.Certificate{}}}, - }) - return New(&config, r) + hosts, _ := host.New([]host.Config{{Name: "jackal.im", Certificate: tls.Certificate{}}}) + + rep, _ := storage.New(&storage.Config{Type: storage.Memory}) + r, _ := router.New( + hosts, + c2srouter.New(rep.User(), rep.BlockList()), + nil, + ) + return New(&config, r, rep, "alloc-1234") } diff --git a/module/offline/offline.go b/module/offline/offline.go index 891f72e39..a735b8677 100644 --- a/module/offline/offline.go +++ b/module/offline/offline.go @@ -6,32 +6,38 @@ package offline import ( + "context" + "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module/xep0030" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" ) const offlineNamespace = "msgoffline" +const hintsNamespace = "urn:xmpp:hints" + const offlineDeliveredCtxKey = "offline:delivered" // Offline represents an offline server stream module. type Offline struct { - cfg *Config - router *router.Router - runQueue *runqueue.RunQueue + cfg *Config + runQueue *runqueue.RunQueue + router router.Router + offlineRep repository.Offline } // New returns an offline server stream module. -func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Offline { +func New(config *Config, disco *xep0030.DiscoInfo, router router.Router, offlineRep repository.Offline) *Offline { r := &Offline{ - cfg: config, - router: router, - runQueue: runqueue.New("xep0030"), + cfg: config, + runQueue: runqueue.New("xep0030"), + router: router, + offlineRep: offlineRep, } if disco != nil { disco.RegisterServerFeature(offlineNamespace) @@ -40,14 +46,14 @@ func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Offli } // ArchiveMessage archives a new offline messages into the storage. -func (x *Offline) ArchiveMessage(message *xmpp.Message) { - x.runQueue.Run(func() { x.archiveMessage(message) }) +func (x *Offline) ArchiveMessage(ctx context.Context, message *xmpp.Message) { + x.runQueue.Run(func() { x.archiveMessage(ctx, message) }) } // DeliverOfflineMessages delivers every archived offline messages to the peer // deleting them from storage. -func (x *Offline) DeliverOfflineMessages(stm stream.C2S) { - x.runQueue.Run(func() { x.deliverOfflineMessages(stm) }) +func (x *Offline) DeliverOfflineMessages(ctx context.Context, stm stream.C2S) { + x.runQueue.Run(func() { x.deliverOfflineMessages(ctx, stm) }) } // Shutdown shuts down offline module. @@ -58,25 +64,25 @@ func (x *Offline) Shutdown() error { return nil } -func (x *Offline) archiveMessage(message *xmpp.Message) { +func (x *Offline) archiveMessage(ctx context.Context, message *xmpp.Message) { if !isMessageArchivable(message) { return } toJID := message.ToJID() - queueSize, err := storage.CountOfflineMessages(toJID.Node()) + queueSize, err := x.offlineRep.CountOfflineMessages(ctx, toJID.Node()) if err != nil { log.Error(err) return } if queueSize >= x.cfg.QueueSize { - x.router.Route(message.ServiceUnavailableError()) + _ = x.router.Route(ctx, message.ServiceUnavailableError()) return } delayed, _ := xmpp.NewMessageFromElement(message, message.FromJID(), message.ToJID()) delayed.Delay(message.FromJID().Domain(), "Offline Storage") - if err := storage.InsertOfflineMessage(delayed, toJID.Node()); err != nil { + if err := x.offlineRep.InsertOfflineMessage(ctx, delayed, toJID.Node()); err != nil { log.Error(err) - x.router.Route(message.InternalServerError()) + _ = x.router.Route(ctx, message.InternalServerError()) return } log.Infof("archived offline message... id: %s", message.ID()) @@ -88,13 +94,14 @@ func (x *Offline) archiveMessage(message *xmpp.Message) { } } -func (x *Offline) deliverOfflineMessages(stm stream.C2S) { - if stm.GetBool(offlineDeliveredCtxKey) { +func (x *Offline) deliverOfflineMessages(ctx context.Context, stm stream.C2S) { + delivered, _ := stm.Value(offlineDeliveredCtxKey).(bool) + if delivered { return // already delivered } // deliver offline messages userJID := stm.JID() - messages, err := storage.FetchOfflineMessages(userJID.Node()) + messages, err := x.offlineRep.FetchOfflineMessages(ctx, userJID.Node()) if err != nil { log.Error(err) return @@ -104,15 +111,21 @@ func (x *Offline) deliverOfflineMessages(stm stream.C2S) { } log.Infof("delivering offline messages: %s... count: %d", userJID, len(messages)) - for _, m := range messages { - _ = x.router.Route(&m) + for i := 0; i < len(messages); i++ { + _ = x.router.Route(ctx, &messages[i]) } - if err := storage.DeleteOfflineMessages(userJID.Node()); err != nil { + if err := x.offlineRep.DeleteOfflineMessages(ctx, userJID.Node()); err != nil { log.Error(err) } - stm.SetBool(offlineDeliveredCtxKey, true) + stm.SetValue(offlineDeliveredCtxKey, true) } func isMessageArchivable(message *xmpp.Message) bool { + if message.Elements().ChildNamespace("no-store", hintsNamespace) != nil { + return false + } + if message.Elements().ChildNamespace("store", hintsNamespace) != nil { + return true + } return message.IsNormal() || (message.IsChat() && message.IsMessageWithBody()) } diff --git a/module/offline/offline_test.go b/module/offline/offline_test.go index 72c18824e..09c9429a1 100644 --- a/module/offline/offline_test.go +++ b/module/offline/offline_test.go @@ -6,13 +6,16 @@ package offline import ( + "context" "crypto/tls" "testing" "time" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -21,28 +24,29 @@ import ( ) func TestOffline_ArchiveMessage(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("juliet", "jackal.im", "garden", true) stm := stream.NewMockC2S(uuid.New(), j1) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) - x := New(&Config{QueueSize: 1}, nil, r) - defer x.Shutdown() + x := New(&Config{QueueSize: 1}, nil, r, s) + defer func() { _ = x.Shutdown() }() msgID := uuid.New() msg := xmpp.NewMessageType(msgID, "normal") msg.SetFromJID(j1) msg.SetToJID(j2) - x.ArchiveMessage(msg) + x.ArchiveMessage(context.Background(), msg) // wait for insertion... time.Sleep(time.Millisecond * 250) - msgs, err := storage.FetchOfflineMessages("juliet") + msgs, err := s.FetchOfflineMessages(context.Background(), "juliet") require.Nil(t, err) require.Equal(t, 1, len(msgs)) @@ -50,7 +54,7 @@ func TestOffline_ArchiveMessage(t *testing.T) { msg2.SetFromJID(j1) msg2.SetToJID(j2) - x.ArchiveMessage(msg) + x.ArchiveMessage(context.Background(), msg) elem := stm.ReceiveElement() require.NotNil(t, elem) @@ -58,25 +62,28 @@ func TestOffline_ArchiveMessage(t *testing.T) { // deliver offline messages... stm2 := stream.NewMockC2S("abcd", j2) - r.Bind(stm2) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) - x2 := New(&Config{QueueSize: 1}, nil, r) - defer x2.Shutdown() + r.Bind(context.Background(), stm2) - x2.DeliverOfflineMessages(stm2) + x2 := New(&Config{QueueSize: 1}, nil, r, s) + defer func() { _ = x.Shutdown() }() + + x2.DeliverOfflineMessages(context.Background(), stm2) elem = stm2.ReceiveElement() require.NotNil(t, elem) require.Equal(t, msgID, elem.ID()) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, *memorystorage.Offline) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + s := memorystorage.NewOffline() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r, s } diff --git a/module/roster/roster.go b/module/roster/roster.go index 43ef996d0..939b528d9 100644 --- a/module/roster/roster.go +++ b/module/roster/roster.go @@ -6,17 +6,19 @@ package roster import ( + "context" "fmt" "strconv" - "sync" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/module/xep0115" + "github.com/ortuman/jackal/module/xep0163" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" @@ -34,65 +36,55 @@ type Config struct { // Roster represents a roster server stream module. type Roster struct { cfg *Config - router *router.Router - onlineJIDs sync.Map runQueue *runqueue.RunQueue + router router.Router + userRep repository.User + rosterRep repository.Roster + pep *xep0163.Pep + entityCaps *xep0115.EntityCaps } // New returns a roster server stream module. -func New(cfg *Config, router *router.Router) *Roster { +func New(cfg *Config, entityCaps *xep0115.EntityCaps, pep *xep0163.Pep, router router.Router, userRep repository.User, rosterRep repository.Roster) *Roster { r := &Roster{ - cfg: cfg, - router: router, - runQueue: runqueue.New("roster"), + cfg: cfg, + runQueue: runqueue.New("roster"), + router: router, + userRep: userRep, + rosterRep: rosterRep, + entityCaps: entityCaps, + pep: pep, } return r } -// MatchesIQ returns whether or not an IQ should be -// processed by the roster module. +// MatchesIQ returns whether or not an IQ should be processed by the roster module. func (x *Roster) MatchesIQ(iq *xmpp.IQ) bool { return iq.Elements().ChildNamespace("query", rosterNamespace) != nil } -// ProcessIQ processes a roster IQ taking according actions -// over the associated stream. -func (x *Roster) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a roster IQ taking according actions over the associated stream. +func (x *Roster) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - stm := x.router.UserStream(iq.FromJID()) + stm := x.router.LocalStream(iq.FromJID().Node(), iq.FromJID().Resource()) if stm == nil { return } - if err := x.processIQ(iq, stm); err != nil { + if err := x.processRosterIQ(ctx, iq, stm); err != nil { log.Error(err) } }) } // ProcessPresence process an incoming roster presence. -func (x *Roster) ProcessPresence(presence *xmpp.Presence) { +func (x *Roster) ProcessPresence(ctx context.Context, presence *xmpp.Presence) { x.runQueue.Run(func() { - if err := x.processPresence(presence); err != nil { + if err := x.processPresence(ctx, presence); err != nil { log.Error(err) } }) } -// OnlinePresencesMatchingJID returns current online presences matching a given JID. -func (x *Roster) OnlinePresencesMatchingJID(j *jid.JID) []*xmpp.Presence { - var ret []*xmpp.Presence - x.onlineJIDs.Range(func(_, value interface{}) bool { - switch presence := value.(type) { - case *xmpp.Presence: - if x.onlineJIDMatchesJID(presence.FromJID(), j) { - ret = append(ret, presence) - } - } - return true - }) - return ret -} - // Shutdown shuts down roster module. func (x *Roster) Shutdown() error { c := make(chan struct{}) @@ -101,34 +93,34 @@ func (x *Roster) Shutdown() error { return nil } -func (x *Roster) processIQ(iq *xmpp.IQ, stm stream.C2S) error { +func (x *Roster) processRosterIQ(ctx context.Context, iq *xmpp.IQ, stm stream.C2S) error { var err error q := iq.Elements().ChildNamespace("query", rosterNamespace) if iq.IsGet() { - err = x.sendRoster(iq, q, stm) + err = x.sendRoster(ctx, iq, q, stm) } else if iq.IsSet() { - err = x.updateRoster(iq, q, stm) + err = x.updateRoster(ctx, iq, q, stm) } else { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) } return err } -func (x *Roster) sendRoster(iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) error { +func (x *Roster) sendRoster(ctx context.Context, iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) error { if query.Elements().Count() > 0 { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return nil } userJID := stm.JID() log.Infof("retrieving user roster... (%s)", userJID) - itms, ver, err := storage.FetchRosterItems(userJID.Node()) + items, ver, err := x.rosterRep.FetchRosterItems(ctx, userJID.Node()) if err != nil { - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return err } - v := x.parseVer(query.Attributes().Get("ver")) + v := parseVer(query.Attributes().Get("ver")) res := iq.ResultIQ() if v == 0 || v < ver.DeletionVer { @@ -137,63 +129,63 @@ func (x *Roster) sendRoster(iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) er if x.cfg.Versioning { q.SetAttribute("ver", fmt.Sprintf("v%d", ver.Ver)) } - for _, itm := range itms { + for _, itm := range items { q.AppendElement(itm.Element()) } res.AppendElement(q) - stm.SendElement(res) + stm.SendElement(ctx, res) } else { // push roster changes - stm.SendElement(res) - for _, itm := range itms { + stm.SendElement(ctx, res) + for _, itm := range items { if itm.Ver > v { iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) q := xmpp.NewElementNamespace("query", rosterNamespace) q.SetAttribute("ver", fmt.Sprintf("v%d", itm.Ver)) q.AppendElement(itm.Element()) iq.AppendElement(q) - stm.SendElement(iq) + stm.SendElement(ctx, iq) } } } - stm.SetBool(rosterRequestedCtxKey, true) + stm.SetValue(rosterRequestedCtxKey, true) return nil } -func (x *Roster) updateRoster(iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) error { - itms := query.Elements().Children("item") - if len(itms) != 1 { - stm.SendElement(iq.BadRequestError()) +func (x *Roster) updateRoster(ctx context.Context, iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) error { + items := query.Elements().Children("item") + if len(items) != 1 { + stm.SendElement(ctx, iq.BadRequestError()) return nil } - ri, err := rostermodel.NewItem(itms[0]) + ri, err := rostermodel.NewItem(items[0]) if err != nil { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return err } switch ri.Subscription { case rostermodel.SubscriptionRemove: - if err := x.removeItem(ri, stm); err != nil { - stm.SendElement(iq.InternalServerError()) + if err := x.removeItem(ctx, ri, stm); err != nil { + stm.SendElement(ctx, iq.InternalServerError()) return err } default: - if err := x.updateItem(ri, stm); err != nil { - stm.SendElement(iq.InternalServerError()) + if err := x.updateItem(ctx, ri, stm); err != nil { + stm.SendElement(ctx, iq.InternalServerError()) return err } } - stm.SendElement(iq.ResultIQ()) + stm.SendElement(ctx, iq.ResultIQ()) return nil } -func (x *Roster) updateItem(ri *rostermodel.Item, stm stream.C2S) error { +func (x *Roster) updateItem(ctx context.Context, ri *rostermodel.Item, stm stream.C2S) error { userJID := stm.JID().ToBareJID() contactJID := ri.ContactJID() log.Infof("updating roster item - contact: %s (%s)", contactJID, userJID) - usrRi, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + usrRi, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.String()) if err != nil { return err } @@ -214,10 +206,10 @@ func (x *Roster) updateItem(ri *rostermodel.Item, stm stream.C2S) error { Ask: ri.Ask, } } - return x.insertItem(usrRi, userJID) + return x.upsertItem(ctx, usrRi, userJID) } -func (x *Roster) removeItem(ri *rostermodel.Item, stm stream.C2S) error { +func (x *Roster) removeItem(ctx context.Context, ri *rostermodel.Item, stm stream.C2S) error { var unsubscribe, unsubscribed *xmpp.Presence userJID := stm.JID().ToBareJID() @@ -225,7 +217,7 @@ func (x *Roster) removeItem(ri *rostermodel.Item, stm stream.C2S) error { log.Infof("removing roster item: %v (%s)", contactJID, userJID) - usrRi, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + usrRi, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.String()) if err != nil { return err } @@ -244,78 +236,83 @@ func (x *Roster) removeItem(ri *rostermodel.Item, stm stream.C2S) error { usrRi.Subscription = rostermodel.SubscriptionRemove usrRi.Ask = false - _, err := x.deleteNotification(contactJID.Node(), userJID) + _, err := x.deleteNotification(ctx, contactJID.Node(), userJID) if err != nil { return err } - if err := x.deleteItem(usrRi, userJID); err != nil { + if err := x.deleteItem(ctx, usrRi, userJID); err != nil { return err } + // auto-unsubscribe from all user virtual nodes + x.unsubscribeFromVirtualNodes(ctx, userJID.String(), contactJID) } - if x.router.IsLocalHost(contactJID.Domain()) { - cntRi, err := storage.FetchRosterItem(contactJID.Node(), userJID.String()) + + if x.router.Hosts().IsLocalHost(contactJID.Domain()) { + cntRi, err := x.rosterRep.FetchRosterItem(ctx, contactJID.Node(), userJID.String()) if err != nil { return err } if cntRi != nil { if cntRi.Subscription == rostermodel.SubscriptionFrom || cntRi.Subscription == rostermodel.SubscriptionBoth { - x.routePresencesFrom(contactJID, userJID, xmpp.UnavailableType) + x.routePresencesFrom(ctx, contactJID, userJID, xmpp.UnavailableType) } switch cntRi.Subscription { case rostermodel.SubscriptionBoth: cntRi.Subscription = rostermodel.SubscriptionTo - if x.insertItem(cntRi, contactJID); err != nil { + if err := x.upsertItem(ctx, cntRi, contactJID); err != nil { return err } fallthrough default: cntRi.Subscription = rostermodel.SubscriptionNone - if x.insertItem(cntRi, contactJID); err != nil { + if err := x.upsertItem(ctx, cntRi, contactJID); err != nil { return err } } + // auto-unsubscribe from all contact virtual nodes + x.unsubscribeFromVirtualNodes(ctx, contactJID.String(), userJID) } } if unsubscribe != nil { - x.router.Route(unsubscribe) + _ = x.router.Route(ctx, unsubscribe) } if unsubscribed != nil { - x.router.Route(unsubscribed) + _ = x.router.Route(ctx, unsubscribed) } if usrSub == rostermodel.SubscriptionFrom || usrSub == rostermodel.SubscriptionBoth { - x.routePresencesFrom(userJID, contactJID, xmpp.UnavailableType) + x.routePresencesFrom(ctx, userJID, contactJID, xmpp.UnavailableType) } return nil } -func (x *Roster) processPresence(presence *xmpp.Presence) error { +func (x *Roster) processPresence(ctx context.Context, presence *xmpp.Presence) error { switch presence.Type() { case xmpp.SubscribeType: - return x.processSubscribe(presence) + return x.processSubscribe(ctx, presence) case xmpp.SubscribedType: - return x.processSubscribed(presence) + return x.processSubscribed(ctx, presence) case xmpp.UnsubscribeType: - return x.processUnsubscribe(presence) + return x.processUnsubscribe(ctx, presence) case xmpp.UnsubscribedType: - return x.processUnsubscribed(presence) + return x.processUnsubscribed(ctx, presence) case xmpp.ProbeType: - return x.processProbePresence(presence) + return x.processProbePresence(ctx, presence) case xmpp.AvailableType, xmpp.UnavailableType: - return x.processAvailablePresence(presence) + return x.processAvailablePresence(ctx, presence) } return nil } -func (x *Roster) processSubscribe(presence *xmpp.Presence) error { +func (x *Roster) processSubscribe(ctx context.Context, presence *xmpp.Presence) error { userJID := presence.FromJID().ToBareJID() contactJID := presence.ToJID().ToBareJID() log.Infof("processing 'subscribe' - contact: %s (%s)", contactJID, userJID) - if x.router.IsLocalHost(userJID.Domain()) { - usrRi, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + if x.router.Hosts().IsLocalHost(userJID.Domain()) { + usrRi, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.String()) if err != nil { return err } @@ -339,7 +336,7 @@ func (x *Roster) processSubscribe(presence *xmpp.Presence) error { Ask: true, } } - if x.insertItem(usrRi, userJID); err != nil { + if err := x.upsertItem(ctx, usrRi, userJID); err != nil { return err } } @@ -347,28 +344,28 @@ func (x *Roster) processSubscribe(presence *xmpp.Presence) error { p := xmpp.NewPresence(userJID, contactJID, xmpp.SubscribeType) p.AppendElements(presence.Elements().All()) - if x.router.IsLocalHost(contactJID.Domain()) { + if x.router.Hosts().IsLocalHost(contactJID.Domain()) { // archive roster approval notification - if err := x.insertOrUpdateNotification(contactJID.Node(), userJID, p); err != nil { + if err := x.upsertNotification(ctx, contactJID.Node(), userJID, p); err != nil { return err } } - x.router.Route(p) + _ = x.router.Route(ctx, p) return nil } -func (x *Roster) processSubscribed(presence *xmpp.Presence) error { +func (x *Roster) processSubscribed(ctx context.Context, presence *xmpp.Presence) error { userJID := presence.ToJID().ToBareJID() contactJID := presence.FromJID().ToBareJID() log.Infof("processing 'subscribed' - user: %s (%s)", userJID, contactJID) - if x.router.IsLocalHost(contactJID.Domain()) { - _, err := x.deleteNotification(contactJID.Node(), userJID) + if x.router.Hosts().IsLocalHost(contactJID.Domain()) { + _, err := x.deleteNotification(ctx, contactJID.Node(), userJID) if err != nil { return err } - cntRi, err := storage.FetchRosterItem(contactJID.Node(), userJID.String()) + cntRi, err := x.rosterRep.FetchRosterItem(ctx, contactJID.Node(), userJID.String()) if err != nil { return err } @@ -388,7 +385,9 @@ func (x *Roster) processSubscribed(presence *xmpp.Presence) error { Ask: false, } } - if x.insertItem(cntRi, contactJID); err != nil { + x.subscribeToAllVirtualNodes(ctx, contactJID.String(), userJID) // auto-subscribe to all contact virtual nodes + + if err := x.upsertItem(ctx, cntRi, contactJID); err != nil { return err } } @@ -396,8 +395,8 @@ func (x *Roster) processSubscribed(presence *xmpp.Presence) error { p := xmpp.NewPresence(contactJID, userJID, xmpp.SubscribedType) p.AppendElements(presence.Elements().All()) - if x.router.IsLocalHost(userJID.Domain()) { - usrRi, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + if x.router.Hosts().IsLocalHost(userJID.Domain()) { + usrRi, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.String()) if err != nil { return err } @@ -411,25 +410,26 @@ func (x *Roster) processSubscribed(presence *xmpp.Presence) error { return nil } usrRi.Ask = false - if x.insertItem(usrRi, userJID); err != nil { + if err := x.upsertItem(ctx, usrRi, userJID); err != nil { return err } } } - x.router.Route(p) - x.routePresencesFrom(contactJID, userJID, xmpp.AvailableType) + _ = x.router.Route(ctx, p) + x.routePresencesFrom(ctx, contactJID, userJID, xmpp.AvailableType) + return nil } -func (x *Roster) processUnsubscribe(presence *xmpp.Presence) error { +func (x *Roster) processUnsubscribe(ctx context.Context, presence *xmpp.Presence) error { userJID := presence.FromJID().ToBareJID() contactJID := presence.ToJID().ToBareJID() log.Infof("processing 'unsubscribe' - contact: %s (%s)", contactJID, userJID) var usrSub string - if x.router.IsLocalHost(userJID.Domain()) { - usrRi, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + if x.router.Hosts().IsLocalHost(userJID.Domain()) { + usrRi, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.String()) if err != nil { return err } @@ -442,7 +442,7 @@ func (x *Roster) processUnsubscribe(presence *xmpp.Presence) error { default: usrRi.Subscription = rostermodel.SubscriptionNone } - if x.insertItem(usrRi, userJID); err != nil { + if err := x.upsertItem(ctx, usrRi, userJID); err != nil { return err } } @@ -451,8 +451,8 @@ func (x *Roster) processUnsubscribe(presence *xmpp.Presence) error { p := xmpp.NewPresence(userJID, contactJID, xmpp.UnsubscribeType) p.AppendElements(presence.Elements().All()) - if x.router.IsLocalHost(contactJID.Domain()) { - cntRi, err := storage.FetchRosterItem(contactJID.Node(), userJID.String()) + if x.router.Hosts().IsLocalHost(contactJID.Domain()) { + cntRi, err := x.rosterRep.FetchRosterItem(ctx, contactJID.Node(), userJID.String()) if err != nil { return err } @@ -463,28 +463,30 @@ func (x *Roster) processUnsubscribe(presence *xmpp.Presence) error { default: cntRi.Subscription = rostermodel.SubscriptionNone } - if x.insertItem(cntRi, contactJID); err != nil { + if err := x.upsertItem(ctx, cntRi, contactJID); err != nil { return err } } + // auto-unsubscribe from all contact virtual nodes + x.unsubscribeFromVirtualNodes(ctx, contactJID.String(), userJID) } - x.router.Route(p) + _ = x.router.Route(ctx, p) if usrSub == rostermodel.SubscriptionTo || usrSub == rostermodel.SubscriptionBoth { - x.routePresencesFrom(contactJID, userJID, xmpp.UnavailableType) + x.routePresencesFrom(ctx, contactJID, userJID, xmpp.UnavailableType) } return nil } -func (x *Roster) processUnsubscribed(presence *xmpp.Presence) error { +func (x *Roster) processUnsubscribed(ctx context.Context, presence *xmpp.Presence) error { userJID := presence.ToJID().ToBareJID() contactJID := presence.FromJID().ToBareJID() log.Infof("processing 'unsubscribed' - user: %s (%s)", userJID, contactJID) var cntSub string - if x.router.IsLocalHost(contactJID.Domain()) { - deleted, err := x.deleteNotification(contactJID.Node(), userJID) + if x.router.Hosts().IsLocalHost(contactJID.Domain()) { + deleted, err := x.deleteNotification(ctx, contactJID.Node(), userJID) if err != nil { return err } @@ -492,7 +494,7 @@ func (x *Roster) processUnsubscribed(presence *xmpp.Presence) error { if deleted { goto routePresence } - cntRi, err := storage.FetchRosterItem(contactJID.Node(), userJID.String()) + cntRi, err := x.rosterRep.FetchRosterItem(ctx, contactJID.Node(), userJID.String()) if err != nil { return err } @@ -505,18 +507,20 @@ func (x *Roster) processUnsubscribed(presence *xmpp.Presence) error { default: cntRi.Subscription = rostermodel.SubscriptionNone } - if x.insertItem(cntRi, contactJID); err != nil { + if err := x.upsertItem(ctx, cntRi, contactJID); err != nil { return err } } + // auto-unsubscribe from all contact virtual nodes + x.unsubscribeFromVirtualNodes(ctx, contactJID.String(), userJID) } routePresence: // stamp the presence stanza of type "unsubscribed" with the contact's bare JID as the 'from' address p := xmpp.NewPresence(contactJID, userJID, xmpp.UnsubscribedType) p.AppendElements(presence.Elements().All()) - if x.router.IsLocalHost(userJID.Domain()) { - usrRi, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + if x.router.Hosts().IsLocalHost(userJID.Domain()) { + usrRi, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.String()) if err != nil { return err } @@ -530,76 +534,102 @@ routePresence: } } usrRi.Ask = false - if x.insertItem(usrRi, userJID); err != nil { + if err := x.upsertItem(ctx, usrRi, userJID); err != nil { return err } } } - x.router.Route(p) + _ = x.router.Route(ctx, p) if cntSub == rostermodel.SubscriptionFrom || cntSub == rostermodel.SubscriptionBoth { - x.routePresencesFrom(contactJID, userJID, xmpp.UnavailableType) + x.routePresencesFrom(ctx, contactJID, userJID, xmpp.UnavailableType) } return nil } -func (x *Roster) processProbePresence(presence *xmpp.Presence) error { - userJID := presence.ToJID().ToBareJID() - contactJID := presence.FromJID().ToBareJID() +func (x *Roster) processProbePresence(ctx context.Context, presence *xmpp.Presence) error { + userJID := presence.FromJID().ToBareJID() + contactJID := presence.ToJID().ToBareJID() log.Infof("processing 'probe' - user: %s (%s)", userJID, contactJID) - ri, err := storage.FetchRosterItem(userJID.Node(), contactJID.String()) + if !x.router.Hosts().IsLocalHost(contactJID.Domain()) { + _ = x.router.Route(ctx, presence) + return nil + } + ri, err := x.rosterRep.FetchRosterItem(ctx, contactJID.Node(), userJID.String()) if err != nil { return err } - usr, err := storage.FetchUser(userJID.Node()) + if ri == nil || (ri.Subscription != rostermodel.SubscriptionBoth && ri.Subscription != rostermodel.SubscriptionFrom) { + return nil // silently ignore + } + availPresences, err := x.entityCaps.PresencesMatchingJID(ctx, contactJID) if err != nil { return err } - if usr == nil || ri == nil || (ri.Subscription != rostermodel.SubscriptionBoth && ri.Subscription != rostermodel.SubscriptionFrom) { - x.router.Route(xmpp.NewPresence(userJID, contactJID, xmpp.UnsubscribedType)) + if len(availPresences) == 0 { // send last known presence + usr, err := x.userRep.FetchUser(ctx, contactJID.Node()) + if err != nil { + return err + } + if usr == nil || usr.LastPresence == nil { + return nil + } + p := xmpp.NewPresence(usr.LastPresence.FromJID(), userJID, usr.LastPresence.Type()) + p.AppendElements(usr.LastPresence.Elements().All()) + _ = x.router.Route(ctx, p) return nil } - if usr.LastPresence != nil { - p := xmpp.NewPresence(usr.LastPresence.FromJID(), contactJID, usr.LastPresence.Type()) - p.AppendElements(usr.LastPresence.Elements().All()) - x.router.Route(p) + for _, availPresence := range availPresences { + p := xmpp.NewPresence(availPresence.Presence.FromJID(), userJID, xmpp.AvailableType) + p.AppendElements(availPresence.Presence.Elements().All()) + _ = x.router.Route(ctx, p) } return nil } -func (x *Roster) processAvailablePresence(presence *xmpp.Presence) error { +func (x *Roster) processAvailablePresence(ctx context.Context, presence *xmpp.Presence) error { fromJID := presence.FromJID() userJID := fromJID.ToBareJID() contactJID := presence.ToJID().ToBareJID() - replyOnBehalf := x.router.IsLocalHost(userJID.Domain()) && userJID.Matches(contactJID, jid.MatchesBare) + replyOnBehalf := x.router.Hosts().IsLocalHost(userJID.Domain()) && userJID.MatchesWithOptions(contactJID, jid.MatchesBare) // keep track of available presences if presence.IsAvailable() { log.Infof("processing 'available' - user: %s", fromJID) - if _, loaded := x.onlineJIDs.LoadOrStore(fromJID.String(), presence); !loaded { - if replyOnBehalf { - if err := x.deliverRosterPresences(userJID); err != nil { - return err - } + + // register presence + inserted, err := x.entityCaps.RegisterPresence(ctx, presence) + if err != nil { + return err + } + if inserted && replyOnBehalf { + if err := x.deliverRosterPresences(ctx, userJID); err != nil { + return err } + x.sendVirtualNodesLastItems(ctx, fromJID) } } else { log.Infof("processing 'unavailable' - user: %s", fromJID) - x.onlineJIDs.Delete(fromJID.String()) + + // unregister presence + if err := x.entityCaps.UnregisterPresence(ctx, presence.FromJID()); err != nil { + return err + } } if replyOnBehalf { - return x.broadcastPresence(presence) + return x.broadcastPresence(ctx, presence) } - return x.router.Route(presence) + _ = x.router.Route(ctx, presence) + return nil } -func (x *Roster) deliverRosterPresences(userJID *jid.JID) error { +func (x *Roster) deliverRosterPresences(ctx context.Context, userJID *jid.JID) error { // first, deliver pending approval notifications... - rns, err := storage.FetchRosterNotifications(userJID.Node()) + rns, err := x.rosterRep.FetchRosterNotifications(ctx, userJID.Node()) if err != nil { return err } @@ -607,11 +637,11 @@ func (x *Roster) deliverRosterPresences(userJID *jid.JID) error { fromJID, _ := jid.NewWithString(rn.JID, true) p := xmpp.NewPresence(fromJID, userJID, xmpp.SubscribeType) p.AppendElements(rn.Presence.Elements().All()) - _ = x.router.Route(p) + _ = x.router.Route(ctx, p) } // deliver roster online presences - items, _, err := storage.FetchRosterItems(userJID.Node()) + items, _, err := x.rosterRep.FetchRosterItems(ctx, userJID.Node()) if err != nil { return err } @@ -619,19 +649,19 @@ func (x *Roster) deliverRosterPresences(userJID *jid.JID) error { switch item.Subscription { case rostermodel.SubscriptionTo, rostermodel.SubscriptionBoth: contactJID := item.ContactJID() - if !x.router.IsLocalHost(contactJID.Domain()) { - _ = x.router.Route(xmpp.NewPresence(userJID, contactJID, xmpp.ProbeType)) + if !x.router.Hosts().IsLocalHost(contactJID.Domain()) { + _ = x.router.Route(ctx, xmpp.NewPresence(userJID, contactJID, xmpp.ProbeType)) continue } - x.routePresencesFrom(contactJID, userJID, xmpp.AvailableType) + x.routePresencesFrom(ctx, contactJID, userJID, xmpp.AvailableType) } } return nil } -func (x *Roster) broadcastPresence(presence *xmpp.Presence) error { +func (x *Roster) broadcastPresence(ctx context.Context, presence *xmpp.Presence) error { fromJID := presence.FromJID() - items, _, err := storage.FetchRosterItems(fromJID.Node()) + items, _, err := x.rosterRep.FetchRosterItems(ctx, fromJID.Node()) if err != nil { return err } @@ -640,15 +670,15 @@ func (x *Roster) broadcastPresence(presence *xmpp.Presence) error { case rostermodel.SubscriptionFrom, rostermodel.SubscriptionBoth: p := xmpp.NewPresence(fromJID, itm.ContactJID(), presence.Type()) p.AppendElements(presence.Elements().All()) - _ = x.router.Route(p) + _ = x.router.Route(ctx, p) } } // update last received presence - if usr, err := storage.FetchUser(fromJID.Node()); err != nil { + if usr, err := x.userRep.FetchUser(ctx, fromJID.Node()); err != nil { return err } else if usr != nil { - return storage.InsertOrUpdateUser(&model.User{ + return x.userRep.UpsertUser(ctx, &model.User{ Username: usr.Username, Password: usr.Password, LastPresence: presence, @@ -657,90 +687,101 @@ func (x *Roster) broadcastPresence(presence *xmpp.Presence) error { return nil } -func (x *Roster) onlineJIDMatchesJID(onlineJID, j *jid.JID) bool { - if j.IsFullWithUser() { - return onlineJID.Matches(j, jid.MatchesNode|jid.MatchesDomain|jid.MatchesResource) - } else if j.IsFullWithServer() { - return onlineJID.Matches(j, jid.MatchesDomain|jid.MatchesResource) - } else if j.IsBare() { - return onlineJID.Matches(j, jid.MatchesNode|jid.MatchesDomain) - } - return onlineJID.Matches(j, jid.MatchesDomain) -} - -func (x *Roster) insertItem(ri *rostermodel.Item, pushTo *jid.JID) error { - v, err := storage.InsertOrUpdateRosterItem(ri) +func (x *Roster) upsertItem(ctx context.Context, ri *rostermodel.Item, pushTo *jid.JID) error { + v, err := x.rosterRep.UpsertRosterItem(ctx, ri) if err != nil { return err } ri.Ver = v.Ver - return x.pushItem(ri, pushTo) + return x.pushItem(ctx, ri, pushTo) } -func (x *Roster) deleteItem(ri *rostermodel.Item, pushTo *jid.JID) error { - v, err := storage.DeleteRosterItem(ri.Username, ri.JID) +func (x *Roster) deleteItem(ctx context.Context, ri *rostermodel.Item, pushTo *jid.JID) error { + v, err := x.rosterRep.DeleteRosterItem(ctx, ri.Username, ri.JID) if err != nil { return err } ri.Ver = v.Ver - return x.pushItem(ri, pushTo) + return x.pushItem(ctx, ri, pushTo) } -func (x *Roster) pushItem(ri *rostermodel.Item, to *jid.JID) error { +func (x *Roster) pushItem(ctx context.Context, ri *rostermodel.Item, to *jid.JID) error { query := xmpp.NewElementNamespace("query", rosterNamespace) if x.cfg.Versioning { query.SetAttribute("ver", fmt.Sprintf("v%d", ri.Ver)) } query.AppendElement(ri.Element()) - stms := x.router.UserStreams(to.Node()) - for _, stm := range stms { - if !stm.GetBool(rosterRequestedCtxKey) { + streams := x.router.LocalStreams(to.Node()) + for _, stm := range streams { + requested, _ := stm.Value(rosterRequestedCtxKey).(bool) + if !requested { continue } pushEl := xmpp.NewIQType(uuid.New(), xmpp.SetType) pushEl.SetTo(stm.JID().String()) pushEl.AppendElement(query) - stm.SendElement(pushEl) + stm.SendElement(ctx, pushEl) } return nil } -func (x *Roster) deleteNotification(contact string, userJID *jid.JID) (deleted bool, err error) { - rn, err := storage.FetchRosterNotification(contact, userJID.String()) +func (x *Roster) deleteNotification(ctx context.Context, contact string, userJID *jid.JID) (deleted bool, err error) { + rn, err := x.rosterRep.FetchRosterNotification(ctx, contact, userJID.String()) if err != nil { return false, err } if rn == nil { return false, nil } - if err := storage.DeleteRosterNotification(contact, userJID.String()); err != nil { + if err := x.rosterRep.DeleteRosterNotification(ctx, contact, userJID.String()); err != nil { return false, err } return true, nil } -func (x *Roster) insertOrUpdateNotification(contact string, userJID *jid.JID, presence *xmpp.Presence) error { +func (x *Roster) upsertNotification(ctx context.Context, contact string, userJID *jid.JID, presence *xmpp.Presence) error { rn := &rostermodel.Notification{ Contact: contact, JID: userJID.String(), Presence: presence, } - return storage.InsertOrUpdateRosterNotification(rn) + return x.rosterRep.UpsertRosterNotification(ctx, rn) } -func (x *Roster) routePresencesFrom(from *jid.JID, to *jid.JID, presenceType string) { - stms := x.router.UserStreams(from.Node()) - for _, stm := range stms { +func (x *Roster) routePresencesFrom(ctx context.Context, from *jid.JID, to *jid.JID, presenceType string) { + streams := x.router.LocalStreams(from.Node()) + for _, stm := range streams { p := xmpp.NewPresence(stm.JID(), to.ToBareJID(), presenceType) if presence := stm.Presence(); presence != nil && presence.IsAvailable() { p.AppendElements(presence.Elements().All()) } - x.router.Route(p) + _ = x.router.Route(ctx, p) + } +} + +func (x *Roster) subscribeToAllVirtualNodes(ctx context.Context, hostJID string, jid *jid.JID) { + if x.pep == nil { + return + } + x.pep.SubscribeToAll(ctx, hostJID, jid) +} + +func (x *Roster) unsubscribeFromVirtualNodes(ctx context.Context, hostJID string, jid *jid.JID) { + if x.pep == nil { + return + } + x.pep.UnsubscribeFromAll(ctx, hostJID, jid) +} + +func (x *Roster) sendVirtualNodesLastItems(ctx context.Context, jid *jid.JID) { + if x.pep == nil { + return } + x.pep.DeliverLastItems(ctx, jid) } -func (x *Roster) parseVer(ver string) int { +func parseVer(ver string) int { if len(ver) > 0 && ver[0] == 'v' { v, _ := strconv.Atoi(ver[1:]) return v diff --git a/module/roster/roster_test.go b/module/roster/roster_test.go index d24d4be88..0aa3a0502 100644 --- a/module/roster/roster_test.go +++ b/module/roster/roster_test.go @@ -6,15 +6,19 @@ package roster import ( + "context" "crypto/tls" "testing" "time" + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/module/xep0115" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + "github.com/ortuman/jackal/router/host" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -23,11 +27,10 @@ import ( ) func TestRoster_MatchesIQ(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") - r := New(&Config{}, rtr) - defer r.Shutdown() + r := New(&Config{}, xep0115.New(rtr, presencesRep, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq.AppendElement(xmpp.NewElementNamespace("query", rosterNamespace)) @@ -36,16 +39,15 @@ func TestRoster_MatchesIQ(t *testing.T) { } func TestRoster_FetchRoster(t *testing.T) { - rtr, s, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j1) - rtr.Bind(stm) + rtr.Bind(context.Background(), stm) - r := New(&Config{}, rtr) - defer r.Shutdown() + r := New(&Config{}, xep0115.New(rtr, presencesRep, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.ResultType) iq.SetFromJID(j1) @@ -54,17 +56,17 @@ func TestRoster_FetchRoster(t *testing.T) { q.AppendElement(xmpp.NewElementName("q2")) iq.AppendElement(q) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) iq.SetType(xmpp.GetType) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) q.ClearElements() - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, "iq", elem.Name()) require.Equal(t, xmpp.ResultType, elem.Type()) @@ -80,7 +82,7 @@ func TestRoster_FetchRoster(t *testing.T) { Ask: true, Groups: []string{"people", "friends"}, } - storage.InsertOrUpdateRosterItem(ri1) + _, _ = rosterRep.UpsertRosterItem(context.Background(), ri1) ri2 := &rostermodel.Item{ Username: "ortuman", @@ -90,19 +92,21 @@ func TestRoster_FetchRoster(t *testing.T) { Ask: true, Groups: []string{"others"}, } - storage.InsertOrUpdateRosterItem(ri2) + _, _ = rosterRep.UpsertRosterItem(context.Background(), ri2) - r = New(&Config{Versioning: true}, rtr) - defer r.Shutdown() + r = New(&Config{Versioning: true}, xep0115.New(rtr, nil, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, "iq", elem.Name()) require.Equal(t, xmpp.ResultType, elem.Type()) query2 := elem.Elements().ChildNamespace("query", rosterNamespace) require.Equal(t, 2, query2.Elements().Count()) - require.True(t, stm.GetBool(rosterRequestedCtxKey)) + + requested, _ := stm.Value(rosterRequestedCtxKey).(bool) + require.True(t, requested) // test versioning iq = xmpp.NewIQType(uuid.New(), xmpp.GetType) @@ -112,7 +116,7 @@ func TestRoster_FetchRoster(t *testing.T) { q.SetAttribute("ver", "v1") iq.AppendElement(q) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, "iq", elem.Name()) require.Equal(t, xmpp.ResultType, elem.Type()) @@ -126,19 +130,18 @@ func TestRoster_FetchRoster(t *testing.T) { item := query2.Elements().Child("item") require.Equal(t, "romeo@jackal.im", item.Attributes().Get("jid")) - s.EnableMockedError() - r = New(&Config{}, rtr) - defer r.Shutdown() + memorystorage.EnableMockedError() + r = New(&Config{}, xep0115.New(rtr, nil, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() } func TestRoster_Update(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "garden", true) j2, _ := jid.New("ortuman", "jackal.im", "balcony", true) @@ -147,13 +150,13 @@ func TestRoster_Update(t *testing.T) { stm1.SetAuthenticated(true) stm2 := stream.NewMockC2S(uuid.New(), j2) stm2.SetAuthenticated(true) - stm2.SetBool(rosterRequestedCtxKey, true) + stm2.SetValue(rosterRequestedCtxKey, true) - r := New(&Config{}, rtr) - defer r.Shutdown() + r := New(&Config{}, xep0115.New(rtr, presencesRep, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() - rtr.Bind(stm1) - rtr.Bind(stm2) + rtr.Bind(context.Background(), stm1) + rtr.Bind(context.Background(), stm2) iqID := uuid.New() iq := xmpp.NewIQType(iqID, xmpp.SetType) @@ -168,14 +171,14 @@ func TestRoster_Update(t *testing.T) { q.AppendElement(item) iq.AppendElement(q) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem := stm1.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) q.ClearElements() q.AppendElement(item) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() require.Equal(t, "iq", elem.Name()) require.Equal(t, xmpp.ResultType, elem.Type()) @@ -190,13 +193,13 @@ func TestRoster_Update(t *testing.T) { q.ClearElements() q.AppendElement(item) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() require.Equal(t, "iq", elem.Name()) require.Equal(t, xmpp.ResultType, elem.Type()) require.Equal(t, iqID, elem.ID()) - ri, err := storage.FetchRosterItem("ortuman", "noelia@jackal.im") + ri, err := rosterRep.FetchRosterItem(context.Background(), "ortuman", "noelia@jackal.im") require.Nil(t, err) require.NotNil(t, ri) require.Equal(t, "ortuman", ri.Username) @@ -205,17 +208,16 @@ func TestRoster_Update(t *testing.T) { } func TestRoster_RemoveItem(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") // insert contact's roster item - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "ortuman", JID: "noelia@jackal.im", Name: "My Juliet", Subscription: rostermodel.SubscriptionBoth, }) - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "noelia", JID: "ortuman@jackal.im", Name: "My Romeo", @@ -224,10 +226,12 @@ func TestRoster_RemoveItem(t *testing.T) { j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - rtr.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + rtr.Bind(context.Background(), stm) - r := New(&Config{}, rtr) - defer r.Shutdown() + r := New(&Config{}, xep0115.New(rtr, presencesRep, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() // remove item iqID := uuid.New() @@ -242,18 +246,17 @@ func TestRoster_RemoveItem(t *testing.T) { q.AppendElement(item) iq.AppendElement(q) - r.ProcessIQ(iq) + r.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, iqID, elem.ID()) - ri, err := storage.FetchRosterItem("ortuman", "noelia@jackal.im") + ri, err := rosterRep.FetchRosterItem(context.Background(), "ortuman", "noelia@jackal.im") require.Nil(t, err) require.Nil(t, ri) } func TestRoster_OnlineJIDs(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("noelia", "jackal.im", "garden", true) @@ -266,39 +269,43 @@ func TestRoster_OnlineJIDs(t *testing.T) { stm2 := stream.NewMockC2S(uuid.New(), j2) stm2.SetAuthenticated(true) - rtr.Bind(stm1) - rtr.Bind(stm2) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + + rtr.Bind(context.Background(), stm1) + rtr.Bind(context.Background(), stm2) // user entity - storage.InsertOrUpdateUser(&model.User{ + _ = userRep.UpsertUser(context.Background(), &model.User{ Username: "ortuman", LastPresence: xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.UnavailableType), }) // roster items - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "noelia", JID: "ortuman@jackal.im", Subscription: rostermodel.SubscriptionBoth, }) - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "ortuman", JID: "noelia@jackal.im", Subscription: rostermodel.SubscriptionBoth, }) // pending notification - storage.InsertOrUpdateRosterNotification(&rostermodel.Notification{ + _ = rosterRep.UpsertRosterNotification(context.Background(), &rostermodel.Notification{ Contact: "ortuman", JID: j3.ToBareJID().String(), Presence: xmpp.NewPresence(j3.ToBareJID(), j1.ToBareJID(), xmpp.SubscribeType), }) - r := New(&Config{}, rtr) + ph := xep0115.New(rtr, presencesRep, "alloc-1234") + r := New(&Config{}, ph, nil, rtr, userRep, rosterRep) defer func() { _ = r.Shutdown() }() // online presence... - r.ProcessPresence(xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.AvailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.AvailableType)) time.Sleep(time.Millisecond * 150) // wait until processed... @@ -315,53 +322,62 @@ func TestRoster_OnlineJIDs(t *testing.T) { require.Equal(t, xmpp.AvailableType, elem.Type()) // check if last presence was updated - usr, err := storage.FetchUser("ortuman") + usr, err := userRep.FetchUser(context.Background(), "ortuman") require.Nil(t, err) require.NotNil(t, usr) require.NotNil(t, usr.LastPresence) require.Equal(t, xmpp.AvailableType, usr.LastPresence.Type()) // send remaining online presences... - r.ProcessPresence(xmpp.NewPresence(j2, j2.ToBareJID(), xmpp.AvailableType)) - r.ProcessPresence(xmpp.NewPresence(j3, j3.ToBareJID(), xmpp.AvailableType)) - r.ProcessPresence(xmpp.NewPresence(j4, j1.ToBareJID(), xmpp.AvailableType)) - r.ProcessPresence(xmpp.NewPresence(j5, j1.ToBareJID(), xmpp.AvailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j2, j2.ToBareJID(), xmpp.AvailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j3, j3.ToBareJID(), xmpp.AvailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j4, j1.ToBareJID(), xmpp.AvailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j5, j1.ToBareJID(), xmpp.AvailableType)) time.Sleep(time.Millisecond * 150) // wait until processed... - require.Equal(t, 1, len(r.OnlinePresencesMatchingJID(j1))) + ln1, _ := ph.PresencesMatchingJID(context.Background(), j1) + require.Equal(t, 1, len(ln1)) j6, _ := jid.NewWithString("jackal.im", true) - require.Equal(t, 4, len(r.OnlinePresencesMatchingJID(j6))) + ln6, _ := ph.PresencesMatchingJID(context.Background(), j6) + require.Equal(t, 4, len(ln6)) j7, _ := jid.NewWithString("jabber.org", true) - require.Equal(t, 1, len(r.OnlinePresencesMatchingJID(j7))) + ln7, _ := ph.PresencesMatchingJID(context.Background(), j7) + require.Equal(t, 1, len(ln7)) j8, _ := jid.NewWithString("jackal.im/balcony", true) - require.Equal(t, 2, len(r.OnlinePresencesMatchingJID(j8))) + ln8, _ := ph.PresencesMatchingJID(context.Background(), j8) + require.Equal(t, 2, len(ln8)) j9, _ := jid.NewWithString("ortuman@jackal.im", true) - require.Equal(t, 2, len(r.OnlinePresencesMatchingJID(j9))) + ln9, _ := ph.PresencesMatchingJID(context.Background(), j9) + require.Equal(t, 2, len(ln9)) // send unavailable presences... - r.ProcessPresence(xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.UnavailableType)) - r.ProcessPresence(xmpp.NewPresence(j2, j2.ToBareJID(), xmpp.UnavailableType)) - r.ProcessPresence(xmpp.NewPresence(j3, j3.ToBareJID(), xmpp.UnavailableType)) - r.ProcessPresence(xmpp.NewPresence(j4, j4.ToBareJID(), xmpp.UnavailableType)) - r.ProcessPresence(xmpp.NewPresence(j5, j1.ToBareJID(), xmpp.UnavailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.UnavailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j2, j2.ToBareJID(), xmpp.UnavailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j3, j3.ToBareJID(), xmpp.UnavailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j4, j4.ToBareJID(), xmpp.UnavailableType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j5, j1.ToBareJID(), xmpp.UnavailableType)) time.Sleep(time.Millisecond * 150) // wait until processed... - require.Equal(t, 0, len(r.OnlinePresencesMatchingJID(j1))) - require.Equal(t, 0, len(r.OnlinePresencesMatchingJID(j6))) - require.Equal(t, 0, len(r.OnlinePresencesMatchingJID(j7))) - require.Equal(t, 0, len(r.OnlinePresencesMatchingJID(j8))) - require.Equal(t, 0, len(r.OnlinePresencesMatchingJID(j9))) + ln1, _ = ph.PresencesMatchingJID(context.Background(), j1) + ln6, _ = ph.PresencesMatchingJID(context.Background(), j6) + ln7, _ = ph.PresencesMatchingJID(context.Background(), j7) + ln8, _ = ph.PresencesMatchingJID(context.Background(), j8) + ln9, _ = ph.PresencesMatchingJID(context.Background(), j9) + require.Equal(t, 0, len(ln1)) + require.Equal(t, 0, len(ln6)) + require.Equal(t, 0, len(ln7)) + require.Equal(t, 0, len(ln8)) + require.Equal(t, 0, len(ln9)) } func TestRoster_Probe(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("noelia", "jackal.im", "garden", true) @@ -369,127 +385,119 @@ func TestRoster_Probe(t *testing.T) { stm := stream.NewMockC2S(uuid.New(), j1) stm.SetAuthenticated(true) - rtr.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) - r := New(&Config{}, rtr) - defer r.Shutdown() + rtr.Bind(context.Background(), stm) - // user doesn't exist... - r.ProcessPresence(xmpp.NewPresence(j1, j2, xmpp.ProbeType)) - elem := stm.ReceiveElement() - require.Equal(t, "presence", elem.Name()) - require.Equal(t, "noelia@jackal.im", elem.From()) - require.Equal(t, xmpp.UnsubscribedType, elem.Type()) + r := New(&Config{}, xep0115.New(rtr, presencesRep, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() - storage.InsertOrUpdateUser(&model.User{ + _ = userRep.UpsertUser(context.Background(), &model.User{ Username: "noelia", LastPresence: xmpp.NewPresence(j2.ToBareJID(), j2.ToBareJID(), xmpp.UnavailableType), }) - // user exists, with no presence subscription... - r.ProcessPresence(xmpp.NewPresence(j1, j2, xmpp.ProbeType)) - elem = stm.ReceiveElement() - require.Equal(t, xmpp.UnsubscribedType, elem.Type()) - - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "noelia", JID: "ortuman@jackal.im", Subscription: rostermodel.SubscriptionFrom, }) - r.ProcessPresence(xmpp.NewPresence(j1, j2, xmpp.ProbeType)) - elem = stm.ReceiveElement() + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1, j2, xmpp.ProbeType)) + elem := stm.ReceiveElement() require.Equal(t, xmpp.UnavailableType, elem.Type()) // test available presence... p2 := xmpp.NewPresence(j2, j2.ToBareJID(), xmpp.AvailableType) - storage.InsertOrUpdateUser(&model.User{ + _ = userRep.UpsertUser(context.Background(), &model.User{ Username: "noelia", LastPresence: p2, }) - r.ProcessPresence(xmpp.NewPresence(j1, j2, xmpp.ProbeType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1, j2, xmpp.ProbeType)) elem = stm.ReceiveElement() require.Equal(t, xmpp.AvailableType, elem.Type()) require.Equal(t, "noelia@jackal.im/garden", elem.From()) } func TestRoster_Subscription(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + rtr, userRep, presencesRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("noelia", "jackal.im", "garden", true) - r := New(&Config{}, rtr) - defer r.Shutdown() + r := New(&Config{}, xep0115.New(rtr, presencesRep, "alloc-1234"), nil, rtr, userRep, rosterRep) + defer func() { _ = r.Shutdown() }() - r.ProcessPresence(xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribeType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribeType)) time.Sleep(time.Millisecond * 150) // wait until processed... - rns, err := storage.FetchRosterNotifications("noelia") + rns, err := rosterRep.FetchRosterNotifications(context.Background(), "noelia") require.Nil(t, err) require.Equal(t, 1, len(rns)) // resend request... - r.ProcessPresence(xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribeType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribeType)) // contact request cancellation - r.ProcessPresence(xmpp.NewPresence(j2.ToBareJID(), j1.ToBareJID(), xmpp.UnsubscribedType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j2.ToBareJID(), j1.ToBareJID(), xmpp.UnsubscribedType)) time.Sleep(time.Millisecond * 150) // wait until processed... - rns, err = storage.FetchRosterNotifications("noelia") + rns, err = rosterRep.FetchRosterNotifications(context.Background(), "noelia") require.Nil(t, err) require.Equal(t, 0, len(rns)) - ri, err := storage.FetchRosterItem("ortuman", "noelia@jackal.im") + ri, err := rosterRep.FetchRosterItem(context.Background(), "ortuman", "noelia@jackal.im") require.Nil(t, err) require.Equal(t, rostermodel.SubscriptionNone, ri.Subscription) // contact accepts request... - r.ProcessPresence(xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribeType)) - r.ProcessPresence(xmpp.NewPresence(j2.ToBareJID(), j1.ToBareJID(), xmpp.SubscribedType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribeType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j2.ToBareJID(), j1.ToBareJID(), xmpp.SubscribedType)) time.Sleep(time.Millisecond * 150) // wait until processed... - ri, err = storage.FetchRosterItem("ortuman", "noelia@jackal.im") + ri, err = rosterRep.FetchRosterItem(context.Background(), "ortuman", "noelia@jackal.im") require.Nil(t, err) require.Equal(t, rostermodel.SubscriptionTo, ri.Subscription) // contact subscribes to user's presence... - r.ProcessPresence(xmpp.NewPresence(j2.ToBareJID(), j1.ToBareJID(), xmpp.SubscribeType)) - r.ProcessPresence(xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribedType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j2.ToBareJID(), j1.ToBareJID(), xmpp.SubscribeType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.SubscribedType)) time.Sleep(time.Millisecond * 150) // wait until processed... - ri, err = storage.FetchRosterItem("noelia", "ortuman@jackal.im") + ri, err = rosterRep.FetchRosterItem(context.Background(), "noelia", "ortuman@jackal.im") require.Nil(t, err) require.Equal(t, rostermodel.SubscriptionBoth, ri.Subscription) // user unsubscribes from contact's presence... - r.ProcessPresence(xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.UnsubscribeType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.UnsubscribeType)) time.Sleep(time.Millisecond * 150) // wait until processed... - ri, err = storage.FetchRosterItem("ortuman", "noelia@jackal.im") + ri, err = rosterRep.FetchRosterItem(context.Background(), "ortuman", "noelia@jackal.im") require.Nil(t, err) require.Equal(t, rostermodel.SubscriptionFrom, ri.Subscription) // user cancels contact subscription - r.ProcessPresence(xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.UnsubscribedType)) + r.ProcessPresence(context.Background(), xmpp.NewPresence(j1.ToBareJID(), j2.ToBareJID(), xmpp.UnsubscribedType)) time.Sleep(time.Millisecond * 150) // wait until processed... - ri, err = storage.FetchRosterItem("ortuman", "noelia@jackal.im") + ri, err = rosterRep.FetchRosterItem(context.Background(), "ortuman", "noelia@jackal.im") require.Nil(t, err) require.Equal(t, rostermodel.SubscriptionNone, ri.Subscription) - ri, err = storage.FetchRosterItem("noelia", "ortuman@jackal.im") + ri, err = rosterRep.FetchRosterItem(context.Background(), "noelia", "ortuman@jackal.im") require.Nil(t, err) require.Equal(t, rostermodel.SubscriptionNone, ri.Subscription) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, repository.User, repository.Presences, repository.Roster) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + userRep := memorystorage.NewUser() + presencesRep := memorystorage.NewPresences() + rosterRep := memorystorage.NewRoster() + r, _ := router.New( + hosts, + c2srouter.New(userRep, memorystorage.NewBlockList()), + nil, + ) + return r, userRep, presencesRep, rosterRep } diff --git a/module/xep0004/field.go b/module/xep0004/field.go old mode 100644 new mode 100755 index 83cb69bec..61d57faf5 --- a/module/xep0004/field.go +++ b/module/xep0004/field.go @@ -11,6 +11,9 @@ import ( "github.com/ortuman/jackal/xmpp" ) +// FormType represents form type constant value. +const FormType = "FORM_TYPE" + const ( // Boolean represents a 'boolean' form field. Boolean = "boolean" diff --git a/module/xep0004/field_test.go b/module/xep0004/field_test.go old mode 100644 new mode 100755 diff --git a/module/xep0004/fields.go b/module/xep0004/fields.go new file mode 100755 index 000000000..26c397917 --- /dev/null +++ b/module/xep0004/fields.go @@ -0,0 +1,35 @@ +package xep0004 + +// Fields represent a set of form fields +type Fields []Field + +// ValueForField returns the associated value for a given field name. +func (f Fields) ValueForField(fieldName string) string { + return f.ValueForFieldOfType(fieldName, "") +} + +// ValuesForField returns all associated values for a given field name. +func (f Fields) ValuesForField(fieldName string) []string { + return f.ValuesForFieldOfType(fieldName, "") +} + +// ValueForFieldOfType returns the associated value for a given field name and type. +func (f Fields) ValueForFieldOfType(fieldName, typ string) string { + for _, field := range f { + if field.Var == fieldName && field.Type == typ && len(field.Values) > 0 { + return field.Values[0] + } + } + return "" +} + +// ValuesForFieldOfType returns all associated values for a given field name and type. +func (f Fields) ValuesForFieldOfType(fieldName, typ string) []string { + var res []string + for _, field := range f { + if field.Var == fieldName && field.Type == typ && len(field.Values) > 0 { + res = append(res, field.Values[0]) + } + } + return res +} diff --git a/module/xep0004/fields_test.go b/module/xep0004/fields_test.go new file mode 100755 index 000000000..7980b7e0c --- /dev/null +++ b/module/xep0004/fields_test.go @@ -0,0 +1,18 @@ +package xep0004 + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFields_ValueForField(t *testing.T) { + f := Fields{ + { + Var: "var1", + Values: []string{"foo"}, + }, + } + require.Equal(t, "foo", f.ValueForField("var1")) + require.Equal(t, "", f.ValueForField("var2")) +} diff --git a/module/xep0004/form.go b/module/xep0004/form.go old mode 100644 new mode 100755 index f8ad4a491..6503bb148 --- a/module/xep0004/form.go +++ b/module/xep0004/form.go @@ -11,7 +11,8 @@ import ( "github.com/ortuman/jackal/xmpp" ) -const formNamespace = "jabber:x:data" +// FormNamespace specifies XEP-0004 namespace constant value. +const FormNamespace = "jabber:x:data" const ( // Form represents a 'form' data form. @@ -33,18 +34,17 @@ type DataForm struct { Type string Title string Instructions string - Fields []Field - Reported []Field - Items [][]Field + Fields Fields + Reported Fields + Items []Fields } -// NewFormFromElement returns a new data form entity reading it -// from it's XMPP representation. +// NewFormFromElement returns a new data form entity reading it from it's XMPP representation. func NewFormFromElement(elem xmpp.XElement) (*DataForm, error) { if n := elem.Name(); n != "x" { return nil, fmt.Errorf("invalid form name: %s", n) } - if ns := elem.Namespace(); ns != formNamespace { + if ns := elem.Namespace(); ns != FormNamespace { return nil, fmt.Errorf("invalid form namespace: %s", ns) } typ := elem.Attributes().Get("type") @@ -84,7 +84,7 @@ func NewFormFromElement(elem xmpp.XElement) (*DataForm, error) { // Element returns data form XMPP representation. func (f *DataForm) Element() xmpp.XElement { - elem := xmpp.NewElementNamespace("x", formNamespace) + elem := xmpp.NewElementNamespace("x", FormNamespace) if len(f.Title) > 0 { titleElem := xmpp.NewElementName("title") titleElem.SetText(f.Title) diff --git a/module/xep0004/form_test.go b/module/xep0004/form_test.go old mode 100644 new mode 100755 index 56a3ab58f..edb467fb1 --- a/module/xep0004/form_test.go +++ b/module/xep0004/form_test.go @@ -21,7 +21,7 @@ func TestDataForm_FromElement(t *testing.T) { _, err = NewFormFromElement(elem) require.NotNil(t, err) - elem.SetNamespace(formNamespace) + elem.SetNamespace(FormNamespace) _, err = NewFormFromElement(elem) require.NotNil(t, err) @@ -66,7 +66,7 @@ func TestDataForm_Element(t *testing.T) { form.Type = Form elem := form.Element() require.Equal(t, "x", elem.Name()) - require.Equal(t, formNamespace, elem.Namespace()) + require.Equal(t, FormNamespace, elem.Namespace()) form.Title = "A title" form.Instructions = "A set of instructions" @@ -80,7 +80,7 @@ func TestDataForm_Element(t *testing.T) { require.Equal(t, "A set of instructions", instElem.Text()) form.Reported = []Field{{Var: "var1"}} - form.Items = [][]Field{{{Var: "var2"}}} + form.Items = []Fields{{{Var: "var2"}}} elem = form.Element() require.NotNil(t, elem.Elements().Child("reported")) diff --git a/module/xep0012/last_activity.go b/module/xep0012/last_activity.go index 2ea0aed1f..5ce89f239 100644 --- a/module/xep0012/last_activity.go +++ b/module/xep0012/last_activity.go @@ -6,16 +6,16 @@ package xep0012 import ( + "context" "strconv" "time" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/module/xep0030" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) @@ -24,17 +24,21 @@ const lastActivityNamespace = "jabber:iq:last" // LastActivity represents a last activity stream module. type LastActivity struct { - router *router.Router + router router.Router + userRep repository.User + rosterRep repository.Roster startTime time.Time runQueue *runqueue.RunQueue } // New returns a last activity IQ handler module. -func New(disco *xep0030.DiscoInfo, router *router.Router) *LastActivity { +func New(disco *xep0030.DiscoInfo, router router.Router, userRep repository.User, rosterRep repository.Roster) *LastActivity { x := &LastActivity{ + runQueue: runqueue.New("xep0012"), router: router, + userRep: userRep, + rosterRep: rosterRep, startTime: time.Now(), - runQueue: runqueue.New("xep0012"), } if disco != nil { disco.RegisterServerFeature(lastActivityNamespace) @@ -49,9 +53,9 @@ func (x *LastActivity) MatchesIQ(iq *xmpp.IQ) bool { } // ProcessIQ processes a last activity IQ taking according actions over the associated stream. -func (x *LastActivity) ProcessIQ(iq *xmpp.IQ) { +func (x *LastActivity) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - x.processIQ(iq) + x.processIQ(ctx, iq) }) } @@ -63,46 +67,46 @@ func (x *LastActivity) Shutdown() error { return nil } -func (x *LastActivity) processIQ(iq *xmpp.IQ) { +func (x *LastActivity) processIQ(ctx context.Context, iq *xmpp.IQ) { fromJID := iq.FromJID() toJID := iq.ToJID() if toJID.IsServer() { - x.sendServerUptime(iq) + x.sendServerUptime(ctx, iq) } else if toJID.IsBare() { - ok, err := x.isSubscribedTo(toJID, fromJID) + ok, err := x.isSubscribedTo(ctx, toJID, fromJID) if err != nil { log.Error(err) - _ = x.router.Route(iq.InternalServerError()) + _ = x.router.Route(ctx, iq.InternalServerError()) return } if ok { - x.sendUserLastActivity(iq, toJID) + x.sendUserLastActivity(ctx, iq, toJID) } else { - _ = x.router.Route(iq.ForbiddenError()) + _ = x.router.Route(ctx, iq.ForbiddenError()) } } else { - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) } } -func (x *LastActivity) sendServerUptime(iq *xmpp.IQ) { +func (x *LastActivity) sendServerUptime(ctx context.Context, iq *xmpp.IQ) { secs := int(time.Duration(time.Now().UnixNano()-x.startTime.UnixNano()) / time.Second) - x.sendReply(iq, secs, "") + x.sendReply(ctx, iq, secs, "") } -func (x *LastActivity) sendUserLastActivity(iq *xmpp.IQ, to *jid.JID) { - if len(x.router.UserStreams(to.Node())) > 0 { // user is online - x.sendReply(iq, 0, "") +func (x *LastActivity) sendUserLastActivity(ctx context.Context, iq *xmpp.IQ, to *jid.JID) { + if len(x.router.LocalStreams(to.Node())) > 0 { // user is online + x.sendReply(ctx, iq, 0, "") return } - usr, err := storage.FetchUser(to.Node()) + usr, err := x.userRep.FetchUser(ctx, to.Node()) if err != nil { log.Error(err) - _ = x.router.Route(iq.InternalServerError()) + _ = x.router.Route(ctx, iq.InternalServerError()) return } if usr == nil { - _ = x.router.Route(iq.ItemNotFoundError()) + _ = x.router.Route(ctx, iq.ItemNotFoundError()) return } var secs int @@ -113,23 +117,23 @@ func (x *LastActivity) sendUserLastActivity(iq *xmpp.IQ, to *jid.JID) { status = st.Text() } } - x.sendReply(iq, secs, status) + x.sendReply(ctx, iq, secs, status) } -func (x *LastActivity) sendReply(iq *xmpp.IQ, secs int, status string) { +func (x *LastActivity) sendReply(ctx context.Context, iq *xmpp.IQ, secs int, status string) { q := xmpp.NewElementNamespace("query", lastActivityNamespace) q.SetText(status) q.SetAttribute("seconds", strconv.Itoa(secs)) res := iq.ResultIQ() res.AppendElement(q) - _ = x.router.Route(res) + _ = x.router.Route(ctx, res) } -func (x *LastActivity) isSubscribedTo(contact *jid.JID, userJID *jid.JID) (bool, error) { - if contact.Matches(userJID, jid.MatchesBare) { +func (x *LastActivity) isSubscribedTo(ctx context.Context, contact *jid.JID, userJID *jid.JID) (bool, error) { + if contact.MatchesWithOptions(userJID, jid.MatchesBare) { return true, nil } - ri, err := storage.FetchRosterItem(userJID.Node(), contact.ToBareJID().String()) + ri, err := x.rosterRep.FetchRosterItem(ctx, userJID.Node(), contact.ToBareJID().String()) if err != nil { return false, err } diff --git a/module/xep0012/last_activity_test.go b/module/xep0012/last_activity_test.go index 1b5283e33..9605b03d5 100644 --- a/module/xep0012/last_activity_test.go +++ b/module/xep0012/last_activity_test.go @@ -6,14 +6,18 @@ package xep0012 import ( + "context" "crypto/tls" "testing" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -22,13 +26,12 @@ import ( ) func TestXEP0012_Matching(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, userRep, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, userRep, rosterRep) + defer func() { _ = x.Shutdown() }() // test MatchesIQ iq1 := xmpp.NewIQType(uuid.New(), xmpp.GetType) @@ -53,26 +56,27 @@ func TestXEP0012_Matching(t *testing.T) { } func TestXEP0012_GetServerLastActivity(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, userRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("", "jackal.im", "", true) j2, _ := jid.New("ortuman", "jackal.im", "garden", true) stm := stream.NewMockC2S("abcd", j2) - defer stm.Disconnect(nil) + stm.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + + defer stm.Disconnect(context.Background(), nil) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, userRep, rosterRep) + defer func() { _ = x.Shutdown() }() - r.Bind(stm) + r.Bind(context.Background(), stm) iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq.SetFromJID(j2) iq.SetToJID(j1) iq.AppendElement(xmpp.NewElementNamespace("query", lastActivityNamespace)) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() q := elem.Elements().Child("query") require.NotNil(t, q) @@ -81,25 +85,28 @@ func TestXEP0012_GetServerLastActivity(t *testing.T) { } func TestXEP0012_GetOnlineUserLastActivity(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, userRep, rosterRep := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("noelia", "jackal.im", "garden", true) + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, userRep, rosterRep) + defer func() { _ = x.Shutdown() }() - r.Bind(stm1) + r.Bind(context.Background(), stm1) iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq.SetFromJID(j1) iq.SetToJID(j2.ToBareJID()) iq.AppendElement(xmpp.NewElementNamespace("query", lastActivityNamespace)) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm1.ReceiveElement() require.Equal(t, xmpp.ErrForbidden.Error(), elem.Error().Elements().All()[0].Name()) @@ -108,44 +115,46 @@ func TestXEP0012_GetOnlineUserLastActivity(t *testing.T) { st.SetText("Gone!") p.AppendElement(st) - storage.InsertOrUpdateUser(&model.User{ + _ = userRep.UpsertUser(context.Background(), &model.User{ Username: "noelia", LastPresence: p, }) - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "ortuman", JID: "noelia@jackal.im", Subscription: "both", }) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() q := elem.Elements().ChildNamespace("query", lastActivityNamespace) secs := q.Attributes().Get("seconds") require.True(t, len(secs) > 0) // set as online - r.Bind(stm2) + r.Bind(context.Background(), stm2) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() q = elem.Elements().ChildNamespace("query", lastActivityNamespace) secs = q.Attributes().Get("seconds") require.Equal(t, "0", secs) - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, repository.User, repository.Roster) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + userRep := memorystorage.NewUser() + rosterRep := memorystorage.NewRoster() + r, _ := router.New( + hosts, + c2srouter.New(userRep, memorystorage.NewBlockList()), + nil, + ) + return r, userRep, rosterRep } diff --git a/module/xep0030/disco_info.go b/module/xep0030/disco_info.go index 23b3c5a4e..9fdadfc39 100644 --- a/module/xep0030/disco_info.go +++ b/module/xep0030/disco_info.go @@ -6,10 +6,12 @@ package xep0030 import ( + "context" "sync" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) @@ -22,19 +24,22 @@ const ( // DiscoInfo represents a disco info server stream module. type DiscoInfo struct { mu sync.RWMutex - router *router.Router + router router.Router srvProvider *serverProvider providers map[string]InfoProvider runQueue *runqueue.RunQueue } // New returns a disco info IQ handler module. -func New(router *router.Router) *DiscoInfo { +func New(router router.Router, rosterRep repository.Roster) *DiscoInfo { di := &DiscoInfo{ - router: router, - srvProvider: &serverProvider{router: router}, - providers: make(map[string]InfoProvider), - runQueue: runqueue.New("xep0030"), + router: router, + srvProvider: &serverProvider{ + router: router, + rosterRep: rosterRep, + }, + providers: make(map[string]InfoProvider), + runQueue: runqueue.New("xep0030"), } di.RegisterServerFeature(discoItemsNamespace) di.RegisterServerFeature(discoInfoNamespace) @@ -97,11 +102,10 @@ func (x *DiscoInfo) MatchesIQ(iq *xmpp.IQ) bool { return iq.IsGet() && (q.Namespace() == discoInfoNamespace || q.Namespace() == discoItemsNamespace) } -// ProcessIQ processes a disco info IQ taking according actions -// over the associated stream. -func (x *DiscoInfo) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a disco info IQ taking according actions over the associated stream. +func (x *DiscoInfo) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - x.processIQ(iq) + x.processIQ(ctx, iq) }) } @@ -113,52 +117,52 @@ func (x *DiscoInfo) Shutdown() error { return nil } -func (x *DiscoInfo) processIQ(iq *xmpp.IQ) { +func (x *DiscoInfo) processIQ(ctx context.Context, iq *xmpp.IQ) { fromJID := iq.FromJID() toJID := iq.ToJID() var prov InfoProvider - if x.router.IsLocalHost(toJID.Domain()) { - prov = x.srvProvider + if x.router.Hosts().IsLocalHost(toJID.Domain()) { + if p := x.providers[toJID.String()]; p != nil { + prov = p + } else { + prov = x.srvProvider + } } else { prov = x.providers[toJID.Domain()] if prov == nil { - _ = x.router.Route(iq.ItemNotFoundError()) + _ = x.router.Route(ctx, iq.ItemNotFoundError()) return } } - if prov == nil { - _ = x.router.Route(iq.ItemNotFoundError()) - return - } q := iq.Elements().Child("query") node := q.Attributes().Get("node") if q != nil { switch q.Namespace() { case discoInfoNamespace: - x.sendDiscoInfo(prov, toJID, fromJID, node, iq) + x.sendDiscoInfo(ctx, prov, toJID, fromJID, node, iq) return case discoItemsNamespace: - x.sendDiscoItems(prov, toJID, fromJID, node, iq) + x.sendDiscoItems(ctx, prov, toJID, fromJID, node, iq) return } } - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) } -func (x *DiscoInfo) sendDiscoInfo(prov InfoProvider, toJID, fromJID *jid.JID, node string, iq *xmpp.IQ) { - features, sErr := prov.Features(toJID, fromJID, node) +func (x *DiscoInfo) sendDiscoInfo(ctx context.Context, prov InfoProvider, toJID, fromJID *jid.JID, node string, iq *xmpp.IQ) { + features, sErr := prov.Features(ctx, toJID, fromJID, node) if sErr != nil { - _ = x.router.Route(xmpp.NewErrorStanzaFromStanza(iq, sErr, nil)) + _ = x.router.Route(ctx, xmpp.NewErrorStanzaFromStanza(iq, sErr, nil)) return } else if len(features) == 0 { - _ = x.router.Route(iq.ItemNotFoundError()) + _ = x.router.Route(ctx, iq.ItemNotFoundError()) return } result := iq.ResultIQ() query := xmpp.NewElementNamespace("query", discoInfoNamespace) - identities := prov.Identities(toJID, fromJID, node) + identities := prov.Identities(ctx, toJID, fromJID, node) for _, identity := range identities { identityEl := xmpp.NewElementName("identity") identityEl.SetAttribute("category", identity.Category) @@ -175,22 +179,22 @@ func (x *DiscoInfo) sendDiscoInfo(prov InfoProvider, toJID, fromJID *jid.JID, no featureEl.SetAttribute("var", feature) query.AppendElement(featureEl) } - form, sErr := prov.Form(toJID, fromJID, node) + form, sErr := prov.Form(ctx, toJID, fromJID, node) if sErr != nil { - _ = x.router.Route(xmpp.NewErrorStanzaFromStanza(iq, sErr, nil)) + _ = x.router.Route(ctx, xmpp.NewErrorStanzaFromStanza(iq, sErr, nil)) return } if form != nil { query.AppendElement(form.Element()) } result.AppendElement(query) - _ = x.router.Route(result) + _ = x.router.Route(ctx, result) } -func (x *DiscoInfo) sendDiscoItems(prov InfoProvider, toJID, fromJID *jid.JID, node string, iq *xmpp.IQ) { - items, sErr := prov.Items(toJID, fromJID, node) +func (x *DiscoInfo) sendDiscoItems(ctx context.Context, prov InfoProvider, toJID, fromJID *jid.JID, node string, iq *xmpp.IQ) { + items, sErr := prov.Items(ctx, toJID, fromJID, node) if sErr != nil { - _ = x.router.Route(xmpp.NewErrorStanzaFromStanza(iq, sErr, nil)) + _ = x.router.Route(ctx, xmpp.NewErrorStanzaFromStanza(iq, sErr, nil)) return } result := iq.ResultIQ() @@ -207,5 +211,5 @@ func (x *DiscoInfo) sendDiscoItems(prov InfoProvider, toJID, fromJID *jid.JID, n query.AppendElement(itemEl) } result.AppendElement(query) - _ = x.router.Route(result) + _ = x.router.Route(ctx, result) } diff --git a/module/xep0030/disco_info_test.go b/module/xep0030/disco_info_test.go index f2f69057b..e4b649dfe 100644 --- a/module/xep0030/disco_info_test.go +++ b/module/xep0030/disco_info_test.go @@ -6,13 +6,17 @@ package xep0030 import ( + "context" "crypto/tls" "testing" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/module/xep0004" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -21,13 +25,12 @@ import ( ) func TestXEP0030_Matching(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) - x := New(r) - defer x.Shutdown() + x := New(r, rosterRep) + defer func() { _ = x.Shutdown() }() // test MatchesIQ iq1 := xmpp.NewIQType(uuid.New(), xmpp.GetType) @@ -54,17 +57,18 @@ func TestXEP0030_Matching(t *testing.T) { } func TestXEP0030_SendFeatures(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) srvJid, _ := jid.New("", "jackal.im", "", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) - x := New(r) - defer x.Shutdown() + x := New(r, rosterRep) + defer func() { _ = x.Shutdown() }() x.RegisterServerFeature("s0") x.RegisterServerFeature("s1") @@ -77,7 +81,7 @@ func TestXEP0030_SendFeatures(t *testing.T) { iq1.SetToJID(srvJid) iq1.AppendElement(xmpp.NewElementNamespace("query", discoInfoNamespace)) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem := stm.ReceiveElement() require.NotNil(t, elem) q := elem.Elements().ChildNamespace("query", discoInfoNamespace) @@ -90,7 +94,7 @@ func TestXEP0030_SendFeatures(t *testing.T) { x.UnregisterServerFeature("s1") x.UnregisterAccountFeature("af1") - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem = stm.ReceiveElement() q = elem.Elements().ChildNamespace("query", discoInfoNamespace) @@ -98,7 +102,7 @@ func TestXEP0030_SendFeatures(t *testing.T) { require.Equal(t, 5, q.Elements().Count()) iq1.SetToJID(j.ToBareJID()) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem = stm.ReceiveElement() q = elem.Elements().ChildNamespace("query", discoInfoNamespace) @@ -107,23 +111,24 @@ func TestXEP0030_SendFeatures(t *testing.T) { } func TestXEP0030_SendItems(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) - x := New(r) - defer x.Shutdown() + r.Bind(context.Background(), stm) + + x := New(r, rosterRep) + defer func() { _ = x.Shutdown() }() iq1 := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq1.SetFromJID(j) iq1.SetToJID(j.ToBareJID()) iq1.AppendElement(xmpp.NewElementNamespace("query", discoItemsNamespace)) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem := stm.ReceiveElement() require.NotNil(t, elem) q := elem.Elements().ChildNamespace("query", discoItemsNamespace) @@ -134,48 +139,49 @@ func TestXEP0030_SendItems(t *testing.T) { type testDiscoInfoProvider struct { } -func (tp *testDiscoInfoProvider) Identities(toJID, fromJID *jid.JID, node string) []Identity { +func (tp *testDiscoInfoProvider) Identities(_ context.Context, toJID, fromJID *jid.JID, node string) []Identity { return []Identity{{Name: "test_identity"}} } -func (tp *testDiscoInfoProvider) Items(toJID, fromJID *jid.JID, node string) ([]Item, *xmpp.StanzaError) { +func (tp *testDiscoInfoProvider) Items(_ context.Context, toJID, fromJID *jid.JID, node string) ([]Item, *xmpp.StanzaError) { return []Item{{Jid: "test.jackal.im"}}, nil } -func (tp *testDiscoInfoProvider) Features(toJID, fromJID *jid.JID, node string) ([]Feature, *xmpp.StanzaError) { +func (tp *testDiscoInfoProvider) Features(_ context.Context, toJID, fromJID *jid.JID, node string) ([]Feature, *xmpp.StanzaError) { return []Feature{"com.jackal.im.feature"}, nil } -func (tp *testDiscoInfoProvider) Form(toJID, fromJID *jid.JID, node string) (*xep0004.DataForm, *xmpp.StanzaError) { +func (tp *testDiscoInfoProvider) Form(_ context.Context, toJID, fromJID *jid.JID, node string) (*xep0004.DataForm, *xmpp.StanzaError) { return nil, nil } func TestXEP0030_Provider(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) compJID, _ := jid.New("", "test.jackal.im", "", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) - x := New(r) - defer x.Shutdown() + x := New(r, rosterRep) + defer func() { _ = x.Shutdown() }() iq1 := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq1.SetFromJID(j) iq1.SetToJID(compJID) iq1.AppendElement(xmpp.NewElementNamespace("query", discoItemsNamespace)) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem := stm.ReceiveElement() require.True(t, elem.IsError()) require.Equal(t, xmpp.ErrItemNotFound.Error(), elem.Error().Elements().All()[0].Name()) x.RegisterProvider(compJID.String(), &testDiscoInfoProvider{}) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem = stm.ReceiveElement() q := elem.Elements().ChildNamespace("query", discoItemsNamespace) require.NotNil(t, q) @@ -184,19 +190,19 @@ func TestXEP0030_Provider(t *testing.T) { x.UnregisterProvider(compJID.String()) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem = stm.ReceiveElement() require.True(t, elem.IsError()) require.Equal(t, xmpp.ErrItemNotFound.Error(), elem.Error().Elements().All()[0].Name()) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, repository.Roster) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + rosterRep := memorystorage.NewRoster() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r, rosterRep } diff --git a/module/xep0030/infoprovider.go b/module/xep0030/infoprovider.go index e68dc30bd..118001b17 100644 --- a/module/xep0030/infoprovider.go +++ b/module/xep0030/infoprovider.go @@ -6,6 +6,8 @@ package xep0030 import ( + "context" + "github.com/ortuman/jackal/module/xep0004" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -31,17 +33,17 @@ type Item struct { // InfoProvider represents a generic disco info domain provider. type InfoProvider interface { // Identities returns all identities associated to the provider. - Identities(toJID, fromJID *jid.JID, node string) []Identity + Identities(ctx context.Context, toJID, fromJID *jid.JID, node string) []Identity // Items returns all items associated to the provider. // A proper stanza error should be returned in case an error occurs. - Items(toJID, fromJID *jid.JID, node string) ([]Item, *xmpp.StanzaError) + Items(ctx context.Context, toJID, fromJID *jid.JID, node string) ([]Item, *xmpp.StanzaError) // Features returns all features associated to the provider. // A proper stanza error should be returned in case an error occurs. - Features(toJID, fromJID *jid.JID, node string) ([]Feature, *xmpp.StanzaError) + Features(ctx context.Context, toJID, fromJID *jid.JID, node string) ([]Feature, *xmpp.StanzaError) - // Form returns the data form associated to the provider. + // ResultForm returns the data form associated to the provider. // A proper stanza error should be returned in case an error occurs. - Form(toJID, fromJID *jid.JID, node string) (*xep0004.DataForm, *xmpp.StanzaError) + Form(ctx context.Context, toJID, fromJID *jid.JID, node string) (*xep0004.DataForm, *xmpp.StanzaError) } diff --git a/module/xep0030/server_provider.go b/module/xep0030/server_provider.go index e1a8a81f9..329a5672e 100644 --- a/module/xep0030/server_provider.go +++ b/module/xep0030/server_provider.go @@ -6,26 +6,28 @@ package xep0030 import ( + "context" "sync" "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/module/xep0004" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) type serverProvider struct { - router *router.Router + router router.Router + rosterRep repository.Roster mu sync.RWMutex serverItems []Item serverFeatures []Feature accountFeatures []Feature } -func (sp *serverProvider) Identities(toJID, fromJID *jid.JID, node string) []Identity { +func (sp *serverProvider) Identities(_ context.Context, toJID, _ *jid.JID, node string) []Identity { if node != "" { return nil } @@ -35,29 +37,29 @@ func (sp *serverProvider) Identities(toJID, fromJID *jid.JID, node string) []Ide return []Identity{{Type: "registered", Category: "account"}} } -func (sp *serverProvider) Items(toJID, fromJID *jid.JID, node string) ([]Item, *xmpp.StanzaError) { +func (sp *serverProvider) Items(ctx context.Context, toJID, fromJID *jid.JID, node string) ([]Item, *xmpp.StanzaError) { if node != "" { return nil, nil } - var itms []Item + var items []Item if toJID.IsServer() { - itms = append(itms, Item{Jid: fromJID.ToBareJID().String()}) - itms = append(itms, sp.serverItems...) + items = append(items, Item{Jid: fromJID.ToBareJID().String()}) + items = append(items, sp.serverItems...) } else { // add account resources - if sp.isSubscribedTo(toJID, fromJID) { - stms := sp.router.UserStreams(toJID.Node()) - for _, stm := range stms { - itms = append(itms, Item{Jid: stm.JID().String()}) + if sp.isSubscribedTo(ctx, toJID, fromJID) { + streams := sp.router.LocalStreams(toJID.Node()) + for _, stm := range streams { + items = append(items, Item{Jid: stm.JID().String()}) } } else { return nil, xmpp.ErrSubscriptionRequired } } - return itms, nil + return items, nil } -func (sp *serverProvider) Features(toJID, fromJID *jid.JID, node string) ([]Feature, *xmpp.StanzaError) { +func (sp *serverProvider) Features(ctx context.Context, toJID, fromJID *jid.JID, node string) ([]Feature, *xmpp.StanzaError) { sp.mu.RLock() defer sp.mu.RUnlock() if node != "" { @@ -66,13 +68,13 @@ func (sp *serverProvider) Features(toJID, fromJID *jid.JID, node string) ([]Feat if toJID.IsServer() { return sp.serverFeatures, nil } - if sp.isSubscribedTo(toJID, fromJID) { + if sp.isSubscribedTo(ctx, toJID, fromJID) { return sp.accountFeatures, nil } return nil, xmpp.ErrSubscriptionRequired } -func (sp *serverProvider) Form(toJID, fromJID *jid.JID, node string) (*xep0004.DataForm, *xmpp.StanzaError) { +func (sp *serverProvider) Form(_ context.Context, _, _ *jid.JID, _ string) (*xep0004.DataForm, *xmpp.StanzaError) { return nil, nil } @@ -142,11 +144,11 @@ func (sp *serverProvider) unregisterAccountFeature(feature Feature) { } } -func (sp *serverProvider) isSubscribedTo(contact *jid.JID, userJID *jid.JID) bool { - if contact.Matches(userJID, jid.MatchesBare) { +func (sp *serverProvider) isSubscribedTo(ctx context.Context, contact *jid.JID, userJID *jid.JID) bool { + if contact.MatchesWithOptions(userJID, jid.MatchesBare) { return true } - ri, err := storage.FetchRosterItem(userJID.Node(), contact.ToBareJID().String()) + ri, err := sp.rosterRep.FetchRosterItem(ctx, userJID.Node(), contact.ToBareJID().String()) if err != nil { log.Error(err) return false diff --git a/module/xep0030/server_provider_test.go b/module/xep0030/server_provider_test.go index 59fc02e16..5b955587b 100644 --- a/module/xep0030/server_provider_test.go +++ b/module/xep0030/server_provider_test.go @@ -6,11 +6,11 @@ package xep0030 import ( + "context" "sort" "testing" - "github.com/ortuman/jackal/model/rostermodel" - "github.com/ortuman/jackal/storage" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -19,7 +19,10 @@ import ( ) func TestServerProvider_Features(t *testing.T) { - var sp serverProvider + r, rosterRep := setupTest("jackal.im") + + var sp = serverProvider{router: r, rosterRep: rosterRep} + sp.registerServerFeature("sf0") sp.registerServerFeature("sf1") sp.registerServerFeature("sf1") @@ -38,44 +41,44 @@ func TestServerProvider_Features(t *testing.T) { accJID, _ := jid.New("ortuman", "jackal.im", "garden", true) accJID2, _ := jid.New("noelia", "jackal.im", "balcony", true) - features, sErr := sp.Features(srvJID, accJID, "node") + features, sErr := sp.Features(context.Background(), srvJID, accJID, "node") require.Nil(t, features) require.Nil(t, sErr) - features, sErr = sp.Features(srvJID, accJID, "") + features, sErr = sp.Features(context.Background(), srvJID, accJID, "") require.Equal(t, features, []Feature{"sf0"}) require.Nil(t, sErr) - features, sErr = sp.Features(accJID.ToBareJID(), accJID, "") + features, sErr = sp.Features(context.Background(), accJID.ToBareJID(), accJID, "") require.Equal(t, features, []Feature{"af1"}) require.Nil(t, sErr) - features, sErr = sp.Features(accJID2.ToBareJID(), accJID, "") + features, sErr = sp.Features(context.Background(), accJID2.ToBareJID(), accJID, "") require.Nil(t, features) require.Equal(t, sErr, xmpp.ErrSubscriptionRequired) } func TestServerProvider_Identities(t *testing.T) { - var sp serverProvider + r, rosterRep := setupTest("jackal.im") + + var sp = serverProvider{router: r, rosterRep: rosterRep} srvJID, _ := jid.New("", "jackal.im", "", true) accJID, _ := jid.New("ortuman", "jackal.im", "garden", true) - require.Nil(t, sp.Identities(srvJID, accJID, "node")) + require.Nil(t, sp.Identities(context.Background(), srvJID, accJID, "node")) - require.Equal(t, sp.Identities(srvJID, accJID, ""), []Identity{ + require.Equal(t, sp.Identities(context.Background(), srvJID, accJID, ""), []Identity{ {Type: "im", Category: "server", Name: "jackal"}, }) - require.Equal(t, sp.Identities(accJID.ToBareJID(), accJID, ""), []Identity{ + require.Equal(t, sp.Identities(context.Background(), accJID.ToBareJID(), accJID, ""), []Identity{ {Type: "registered", Category: "account"}, }) } func TestServerProvider_Items(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, rosterRep := setupTest("jackal.im") - var sp serverProvider - sp.router = r + var sp = serverProvider{router: r, rosterRep: rosterRep} srvJID, _ := jid.New("", "jackal.im", "", true) accJID1, _ := jid.New("ortuman", "jackal.im", "garden", true) @@ -86,30 +89,30 @@ func TestServerProvider_Items(t *testing.T) { stm2 := stream.NewMockC2S(uuid.New(), accJID2) stm3 := stream.NewMockC2S(uuid.New(), accJID3) - r.Bind(stm1) - r.Bind(stm2) - r.Bind(stm3) + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + r.Bind(context.Background(), stm3) - items, sErr := sp.Items(srvJID, accJID1, "node") + items, sErr := sp.Items(context.Background(), srvJID, accJID1, "node") require.Nil(t, items) require.Nil(t, sErr) - items, sErr = sp.Items(srvJID, accJID1, "") + items, sErr = sp.Items(context.Background(), srvJID, accJID1, "") require.Equal(t, items, []Item{ {Jid: accJID1.ToBareJID().String()}, }) require.Nil(t, sErr) - items, sErr = sp.Items(accJID2.ToBareJID(), accJID1, "") + items, sErr = sp.Items(context.Background(), accJID2.ToBareJID(), accJID1, "") require.Nil(t, items) require.Equal(t, sErr, xmpp.ErrSubscriptionRequired) - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "ortuman", JID: "noelia@jackal.im", Subscription: "both", }) - items, sErr = sp.Items(accJID2.ToBareJID(), accJID1, "") + items, sErr = sp.Items(context.Background(), accJID2.ToBareJID(), accJID1, "") sort.Slice(items, func(i, j int) bool { return items[i].Jid < items[j].Jid }) require.Equal(t, items, []Item{ diff --git a/module/xep0045/config.go b/module/xep0045/config.go new file mode 100644 index 000000000..4d03a6f06 --- /dev/null +++ b/module/xep0045/config.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/pkg/errors" +) + +const ( + defaultServiceName = "Chatroom Server" +) + +// Config represents XEP-0045 Multi-User Chat configuration +type Config struct { + MucHost string + Name string + RoomDefaults mucmodel.RoomConfig +} + +type configProxy struct { + MucHost string `yaml:"host"` + Name string `yaml:"name"` + RoomDefaults mucmodel.RoomConfig `yaml:"room_defaults"` +} + +// UnmarshalYAML satisfies Unmarshaler interface. +func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := configProxy{} + if err := unmarshal(&p); err != nil { + return err + } + cfg.MucHost = p.MucHost + if len(cfg.MucHost) == 0 { + return errors.New("muc: must specify a service hostname") + } + cfg.Name = p.Name + if len(cfg.Name) == 0 { + cfg.Name = defaultServiceName + } + cfg.RoomDefaults = p.RoomDefaults + return nil +} diff --git a/module/xep0045/config_test.go b/module/xep0045/config_test.go new file mode 100644 index 000000000..a50529d6c --- /dev/null +++ b/module/xep0045/config_test.go @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "crypto/tls" + "testing" + + c2srouter "github.com/ortuman/jackal/c2s/router" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/router/host" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +const cfgExample = ` +host: conference.localhost +name: "Test Server" +` + +// mockMucService is a mock structure for testing xep-0045 +type mockMucService struct { + muc *Muc + room *mucmodel.Room + owner *mucmodel.Occupant + ownerFullJID *jid.JID + ownerStm *stream.MockC2S + occ *mucmodel.Occupant + occFullJID *jid.JID + occStm *stream.MockC2S +} + +func TestXEP0045_MucConfig(t *testing.T) { + badCfg := `host:` + cfg := &Config{} + err := yaml.Unmarshal([]byte(badCfg), &cfg) + require.NotNil(t, err) + + goodCfg := cfgExample + cfg = &Config{} + err = yaml.Unmarshal([]byte(goodCfg), &cfg) + require.Nil(t, err) + require.Equal(t, cfg.MucHost, "conference.localhost") + require.Equal(t, cfg.Name, "Test Server") + require.NotNil(t, cfg.RoomDefaults) +} + +// setupTest returns router and container instance used to setup the mock muc service +func setupTest(domain string) (router.Router, repository.Container) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + rep, _ := memorystorage.New() + r, _ := router.New( + hosts, + c2srouter.New(rep.User(), memorystorage.NewBlockList()), + nil, + ) + return r, rep +} + +// setupMockMucService returns a Muc Service instance, without any rooms +func setupMockMucService() *mockMucService { + r, rep := setupTest("jackal.im") + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, rep.Room(), rep.Occupant()) + return &mockMucService{muc: muc} +} + +// setupTestRoom returns a Muc Service instance, with a room +func setupTestRoom() *mockMucService { + mock := setupMockMucService() + roomConfig := &mucmodel.RoomConfig{ + Open: true, + MaxOccCnt: -1, + } + roomJID, _ := jid.New("room", "conference.jackal.im", "", true) + room := &mucmodel.Room{ + Config: roomConfig, + RoomJID: roomJID, + } + mock.muc.repRoom.UpsertRoom(nil, room) + mock.room = room + return mock +} + +// setupTestRoomAndOwner returns a Muc Service instance, with a room and the room owner +func setupTestRoomAndOwner() *mockMucService { + mock := setupTestRoom() + + ownerUserJID, _ := jid.New("milos", "jackal.im", "phone", true) + ownerOccJID, _ := jid.New("room", "conference.jackal.im", "owner", true) + owner, _ := mucmodel.NewOccupant(ownerOccJID, ownerUserJID.ToBareJID()) + owner.AddResource(ownerUserJID.Resource()) + owner.SetAffiliation("owner") + owner.SetRole("moderator") + mock.muc.AddOccupantToRoom(nil, mock.room, owner) + + ownerStm := stream.NewMockC2S("id-1", ownerUserJID) + ownerStm.SetPresence(xmpp.NewPresence(owner.BareJID, ownerUserJID, xmpp.AvailableType)) + mock.muc.router.Bind(context.Background(), ownerStm) + + mock.owner = owner + mock.ownerStm = ownerStm + mock.ownerFullJID = ownerUserJID + return mock +} + +// setupTestRoomAndOwnerAndOcc returns a Muc Service instance, with a room, owner and an occupant +func setupTestRoomAndOwnerAndOcc() *mockMucService { + mock := setupTestRoomAndOwner() + + occUserJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + occOccJID, _ := jid.New("room", "conference.jackal.im", "occ", true) + occ, _ := mucmodel.NewOccupant(occOccJID, occUserJID.ToBareJID()) + occ.AddResource(occUserJID.Resource()) + occ.SetAffiliation("") + occ.SetRole("") + mock.muc.AddOccupantToRoom(nil, mock.room, occ) + + occStm := stream.NewMockC2S("id-1", occUserJID) + occStm.SetPresence(xmpp.NewPresence(occ.BareJID, occUserJID, xmpp.AvailableType)) + mock.muc.router.Bind(context.Background(), occStm) + + mock.occ = occ + mock.occStm = occStm + mock.occFullJID = occUserJID + return mock +} diff --git a/module/xep0045/disco_provider.go b/module/xep0045/disco_provider.go new file mode 100644 index 000000000..9a1b59b18 --- /dev/null +++ b/module/xep0045/disco_provider.go @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +const ( + // implemented muc namespaces + mucNamespace = "http://jabber.org/protocol/muc" + mucNamespaceUser = "http://jabber.org/protocol/muc#user" + mucNamespaceOwner = "http://jabber.org/protocol/muc#owner" + mucNamespaceAdmin = "http://jabber.org/protocol/muc#admin" + mucNamespaceStableID = "http://jabber.org/protocol/muc#stable_id" + mucNamespaceRoomConfig = "http://jabber.org/protocol/muc#roomconfig" + + // implemented muc room types + mucHidden = "muc_hidden" + mucPublic = "muc_public" + mucMembersOnly = "muc_membersonly" + mucOpen = "muc_open" + mucModerated = "muc_moderated" + mucUnmoderated = "muc_unmoderated" + mucNonAnonymous = "muc_nonanonymous" + mucSemiAnonymous = "muc_semianonymous" + mucPwdProtected = "muc_passwordprotected" + mucUnsecured = "muc_unsecured" + mucPersistent = "muc_persistent" + mucTemporary = "muc_temporary" + + mucUserItem = "x-roomuser-item" +) + +// discoMucProvider represents a service discovery instance for the muc service +type discoMucProvider struct { + service *Muc +} + +// setupDiscoService adds muc discovery items to the xep0030, and registers discoMucProvider +func setupDiscoService(cfg *Config, disco *xep0030.DiscoInfo, mucService *Muc) { + // registering disco item for discovering a muc service + item := xep0030.Item{ + Jid: cfg.MucHost, + Name: cfg.Name, + } + disco.RegisterServerItem(item) + disco.RegisterServerFeature(mucNamespace) + + // registering the discoInfoProvider + provider := &discoMucProvider{ + service: mucService, + } + disco.RegisterProvider(cfg.MucHost, provider) +} + +func (p *discoMucProvider) Identities(ctx context.Context, toJID, fromJID *jid.JID, node string) []xep0030.Identity { + var identities []xep0030.Identity + if toJID != nil && toJID.Node() != "" { + room := p.getRoom(ctx, toJID) + if node == "" { + if room != nil { + identities = append(identities, xep0030.Identity{Type: "text", + Category: "conference", Name: room.Name}) + } + } else if node == mucUserItem { + if room != nil { + occJID, ok := room.GetOccupantJID(fromJID.ToBareJID()) + if ok { + identities = append(identities, xep0030.Identity{Type: "text", + Category: "conference", Name: occJID.Resource()}) + } + } + } + } else { + identities = append(identities, xep0030.Identity{Type: "text", Category: "conference", + Name: p.service.cfg.Name}) + } + return identities +} + +func (p *discoMucProvider) Features(ctx context.Context, toJID, _ *jid.JID, _ string) ([]xep0030.Feature, *xmpp.StanzaError) { + if toJID != nil && toJID.Node() != "" { + return p.roomFeatures(ctx, toJID) + } else { + return []string{mucNamespace}, nil + } +} + +func (p *discoMucProvider) Form(_ context.Context, _, _ *jid.JID, _ string) (*xep0004.DataForm, *xmpp.StanzaError) { + return nil, nil +} + +func (p *discoMucProvider) Items(ctx context.Context, toJID, _ *jid.JID, _ string) ([]xep0030.Item, *xmpp.StanzaError) { + if toJID != nil && toJID.Node() != "" { + return p.roomOccupants(ctx, toJID) + } else { + return p.publicRooms(ctx) + } +} + +func (p *discoMucProvider) roomOccupants(ctx context.Context, roomJID *jid.JID) ([]xep0030.Item, *xmpp.StanzaError) { + var items []xep0030.Item + room := p.getRoom(ctx, roomJID) + if room == nil { + return nil, xmpp.ErrItemNotFound + } + if room.Config.WhoCanGetMemberList() == mucmodel.All { + for _, occJID := range room.GetAllOccupantJIDs() { + items = append(items, xep0030.Item{Jid: occJID.String()}) + } + } + return items, nil +} + +func (p *discoMucProvider) publicRooms(ctx context.Context) ([]xep0030.Item, *xmpp.StanzaError) { + var items []xep0030.Item + p.service.mu.Lock() + for _, r := range p.service.allRooms { + room := p.getRoom(ctx, &r) + if room == nil { + return nil, xmpp.ErrInternalServerError + } + if room.Config.Public && !room.Locked { + item := xep0030.Item{ + Jid: room.RoomJID.String(), + Name: room.Name, + } + items = append(items, item) + } + } + p.service.mu.Unlock() + return items, nil +} + +func (p *discoMucProvider) roomFeatures(ctx context.Context, roomJID *jid.JID) ([]xep0030.Feature, *xmpp.StanzaError) { + room := p.getRoom(ctx, roomJID) + if room == nil { + return nil, xmpp.ErrItemNotFound + } + + features := getRoomFeatures(room) + + return features, nil +} + +func (p *discoMucProvider) getRoom(ctx context.Context, roomJID *jid.JID) *mucmodel.Room { + r, err := p.service.repRoom.FetchRoom(ctx, roomJID) + if err != nil { + return nil + } + return r +} + +func getRoomFeatures(room *mucmodel.Room) []string { + features := []string{mucNamespace, mucNamespaceStableID, mucNamespaceRoomConfig} + + if room.Config.Public { + features = append(features, mucPublic) + } else { + features = append(features, mucHidden) + } + + if room.Config.Open { + features = append(features, mucOpen) + } else { + features = append(features, mucMembersOnly) + } + + if room.Config.Moderated { + features = append(features, mucModerated) + } else { + features = append(features, mucUnmoderated) + } + + if room.Config.NonAnonymous { + features = append(features, mucNonAnonymous) + } else { + features = append(features, mucSemiAnonymous) + } + + if room.Config.PwdProtected { + features = append(features, mucPwdProtected) + } else { + features = append(features, mucUnsecured) + } + + if room.Config.Persistent { + features = append(features, mucPersistent) + } else { + features = append(features, mucTemporary) + } + return features +} diff --git a/module/xep0045/disco_provider_test.go b/module/xep0045/disco_provider_test.go new file mode 100644 index 000000000..662c433a3 --- /dev/null +++ b/module/xep0045/disco_provider_test.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "testing" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_DiscoIdentities(t *testing.T) { + neRoom, _ := jid.New("nonexistent_room", "conference.jackal.im", "", true) + sRoom, _ := jid.New("secretroom", "conference.jackal.im", "", true) + pubRoom, _ := jid.New("publicroom", "conference.jackal.im", "", true) + usrJID, _ := jid.New("ortuman", "jackal.im", "phone", true) + cfgJID, _ := jid.New("", "conference.jackal.im", "", true) + dp := setupDiscoTest() + + ids := dp.Identities(context.Background(), cfgJID, nil, "") + require.Len(t, ids, 1) + require.Equal(t, ids[0].Name, dp.service.cfg.Name) + + ids = dp.Identities(context.Background(), neRoom, nil, "") + require.Len(t, ids, 0) + + ids = dp.Identities(context.Background(), pubRoom, usrJID, mucUserItem) + require.Len(t, ids, 1) + require.Equal(t, ids[0].Name, "nick") + + ids = dp.Identities(context.Background(), sRoom, nil, "") + require.Len(t, ids, 1) + require.Equal(t, ids[0].Name, "Secret room") +} + +func TestXEP0045_DiscoFeatures(t *testing.T) { + neRoom, _ := jid.New("nonexistent_room", "conference.jackal.im", "", true) + sRoom, _ := jid.New("secretroom", "conference.jackal.im", "", true) + dp := setupDiscoTest() + + f, err := dp.Features(context.Background(), nil, nil, "") + require.Nil(t, err) + require.Len(t, f, 1) + require.Equal(t, f[0], mucNamespace) + + f, err = dp.Features(context.Background(), neRoom, nil, "") + require.Nil(t, f) + require.NotNil(t, err) + + f, err = dp.Features(context.Background(), sRoom, nil, "") + require.Nil(t, err) + require.Len(t, f, 9) + require.Equal(t, f[3], mucHidden) +} + +func TestXEP0045_DiscoItems(t *testing.T) { + neRoom, _ := jid.New("nonexistent_room", "conference.jackal.im", "", true) + pRoom, _ := jid.New("publicroom", "conference.jackal.im", "", true) + dp := setupDiscoTest() + + i, err := dp.Items(context.Background(), nil, nil, "") + require.Nil(t, err) + require.Len(t, i, 1) + require.Equal(t, i[0].Name, "Public room") + + i, err = dp.Items(context.Background(), neRoom, nil, "") + require.NotNil(t, err) + require.Nil(t, i) + + i, err = dp.Items(context.Background(), pRoom, nil, "") + require.Nil(t, err) + require.NotNil(t, i) + require.Len(t, i, 1) + require.Equal(t, i[0].Jid, "publicroom@conference.jackal.im/nick") +} + +func setupDiscoTest() *discoMucProvider { + mock := setupMockMucService() + mock.muc.cfg.Name = "Chat Service" + + hiddenRc := &mucmodel.RoomConfig{Public: false} + hJID, _ := jid.New("secretroom", "conference.jackal.im", "", true) + hiddenRoom := mucmodel.Room{ + Name: "Secret room", + Config: hiddenRc, + RoomJID: hJID, + } + + publicRc := &mucmodel.RoomConfig{Public: true} + pJID, _ := jid.New("publicroom", "conference.jackal.im", "", true) + publicRoom := mucmodel.Room{ + Name: "Public room", + Config: publicRc, + RoomJID: pJID, + } + publicRoom.Config.SetWhoCanGetMemberList("all") + oJID, _ := jid.New("publicroom", "conference.jackal.im", "nick", true) + usrJID, _ := jid.New("ortuman", "jackal.im", "phone", true) + o := &mucmodel.Occupant{OccupantJID: oJID, BareJID: usrJID.ToBareJID()} + publicRoom.AddOccupant(o) + + mock.muc.repRoom.UpsertRoom(context.Background(), &publicRoom) + mock.muc.repRoom.UpsertRoom(context.Background(), &hiddenRoom) + mock.muc.allRooms = append(mock.muc.allRooms, *hiddenRoom.RoomJID) + mock.muc.allRooms = append(mock.muc.allRooms, *publicRoom.RoomJID) + + return &discoMucProvider{service: mock.muc} +} diff --git a/module/xep0045/elements.go b/module/xep0045/elements.go new file mode 100644 index 000000000..d721b5088 --- /dev/null +++ b/module/xep0045/elements.go @@ -0,0 +1,303 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +// this elements.go file provides the helper funtions to manipulate the xmpp elements as specified +// in the xep-0045 specification + +import ( + "github.com/google/uuid" + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +func getRoomUpdatedElement(nonAnonymous, updatedAnonimity bool) *xmpp.Element { + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(newStatusElement("104")) + if updatedAnonimity { + if nonAnonymous { + xEl.AppendElement(newStatusElement("172")) + } else { + xEl.AppendElement(newStatusElement("173")) + } + } + msgEl := xmpp.NewElementName("message").SetID(uuid.New().String()).SetType("groupchat") + return msgEl.AppendElement(xEl) +} + +func getOccupantsInfoElement(occupants []*mucmodel.Occupant, id string, + includeUserJID bool) *xmpp.Element { + query := xmpp.NewElementNamespace("query", mucNamespaceAdmin) + for _, o := range occupants { + query.AppendElement(newOccupantItem(o, includeUserJID, true)) + } + iq := xmpp.NewElementName("iq").AppendElement(query) + iq.SetID("id").SetType("result") + return iq +} + +func getUserBannedElement(actor, reason string) *xmpp.Element { + actorEl := xmpp.NewElementName("actor").SetAttribute("nick", actor) + itemEl := xmpp.NewElementName("item").AppendElement(actorEl) + itemEl.SetAttribute("affiliation", "outcast") + itemEl.SetAttribute("role", "none") + if reason != "" { + reasonEl := xmpp.NewElementName("reason").SetText(reason) + itemEl.AppendElement(reasonEl) + } + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(itemEl) + xEl.AppendElement(newStatusElement("301")) + presence := xmpp.NewElementName("presence").SetType("unavailable").AppendElement(xEl) + return presence +} + +func getRoomMemberRemovedElement(actor, reason string) *xmpp.Element { + actorEl := xmpp.NewElementName("actor").SetAttribute("nick", actor) + itemEl := xmpp.NewElementName("item").AppendElement(actorEl) + itemEl.SetAttribute("affiliation", "none") + itemEl.SetAttribute("role", "none") + if reason != "" { + reasonEl := xmpp.NewElementName("reason").SetText(reason) + itemEl.AppendElement(reasonEl) + } + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(itemEl) + xEl.AppendElement(newStatusElement("321")) + presence := xmpp.NewElementName("presence").SetType("unavailable").AppendElement(xEl) + return presence +} + +// getReasonFromItem returns text from the reason element if specified, otherwise empty string +func getReasonFromItem(item xmpp.XElement) string { + reasonEl := item.Elements().Child("reason") + reason := "" + if reasonEl != nil { + reason = reasonEl.Text() + } + return reason +} + +func getOccupantChangeElement(o *mucmodel.Occupant, reason string) *xmpp.Element { + itemEl := xmpp.NewElementName("item") + itemEl.SetAttribute("affiliation", o.GetAffiliation()) + itemEl.SetAttribute("role", o.GetRole()) + itemEl.SetAttribute("nick", o.OccupantJID.Resource()) + if reason != "" { + reasonEl := xmpp.NewElementName("reason").SetText(reason) + itemEl.AppendElement(reasonEl) + } + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(itemEl) + return xmpp.NewElementName("presence").AppendElement(xEl) +} + +func getKickedOccupantElement(actor, reason string, selfNotifying bool) *xmpp.Element { + itemEl := xmpp.NewElementName("item").SetAttribute("affiliation", "none") + itemEl.SetAttribute("role", "none") + actorEl := xmpp.NewElementName("actor").SetAttribute("nick", actor) + itemEl.AppendElement(actorEl) + if reason != "" { + reasonEl := xmpp.NewElementName("reason").SetText(reason) + itemEl.AppendElement(reasonEl) + } + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(itemEl) + xEl.AppendElement(newStatusElement("307")) + if selfNotifying { + xEl.AppendElement(newStatusElement("110")) + } + pEl := xmpp.NewElementName("presence").SetType("unavailable").AppendElement(xEl) + return pEl +} + +// getInvitedUserJID returns jid as specified in the "to" attribute of the invite element +func getInvitedUserJID(message *xmpp.Message) *jid.JID { + invJIDStr := message.Elements().Child("x").Elements().Child("invite").Attributes().Get("to") + invJID, _ := jid.NewWithString(invJIDStr, true) + return invJID +} + +func getMessageElement(body xmpp.XElement, id string, private bool) *xmpp.Element { + msgEl := xmpp.NewElementName("message").AppendElement(body) + + if id != "" { + msgEl.SetID(id) + } else { + msgEl.SetID(uuid.New().String()) + } + + if private { + msgEl.SetType("chat") + msgEl.AppendElement(xmpp.NewElementNamespace("x", mucNamespaceUser)) + } else { + msgEl.SetType("groupchat") + } + + return msgEl +} + +func getDeclineStanza(room *mucmodel.Room, message *xmpp.Message) xmpp.Stanza { + toStr := message.Elements().Child("x").Elements().Child("decline").Attributes().Get("to") + to, _ := jid.NewWithString(toStr, true) + + declineEl := xmpp.NewElementName("decline").SetAttribute("from", + message.FromJID().ToBareJID().String()) + reasonEl := message.Elements().Child("x").Elements().Child("decline").Elements().Child("reason") + if reasonEl != nil { + declineEl.AppendElement(reasonEl) + } + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(declineEl) + msgEl := xmpp.NewElementName("message").AppendElement(xEl).SetID(message.ID()) + msg, err := xmpp.NewMessageFromElement(msgEl, room.RoomJID, to) + if err != nil { + log.Error(err) + return nil + } + return msg +} + +func getInvitationStanza(room *mucmodel.Room, inviteFrom, inviteTo *jid.JID, message *xmpp.Message) xmpp.Stanza { + inviteEl := xmpp.NewElementName("invite").SetAttribute("from", inviteFrom.String()) + reasonEl := message.Elements().Child("x").Elements().Child("invite").Elements().Child("reason") + if reasonEl != nil { + inviteEl.AppendElement(reasonEl) + } + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(inviteEl) + if room.Config.PwdProtected { + pwdEl := xmpp.NewElementName("password").SetText(room.Config.Password) + xEl.AppendElement(pwdEl) + } + msgEl := xmpp.NewElementName("message").AppendElement(xEl).SetID(message.ID()) + msg, err := xmpp.NewMessageFromElement(msgEl, room.RoomJID, inviteTo) + if err != nil { + log.Error(err) + return nil + } + return msg +} + +func getOccupantUnavailableElement(o *mucmodel.Occupant, selfNotifying, + includeUserJID bool) *xmpp.Element { + // get the x element + x := xmpp.NewElementNamespace("x", mucNamespaceUser) + x.AppendElement(newOccupantItem(o, includeUserJID, true)) + x.AppendElement(newStatusElement("303")) + if selfNotifying { + x.AppendElement(newStatusElement("110")) + } + + el := xmpp.NewElementName("presence").AppendElement(x).SetID(uuid.New().String()) + el.SetType("unavailable") + return el +} + +func getPasswordFromPresence(presence *xmpp.Presence) string { + x := presence.Elements().ChildNamespace("x", mucNamespace) + if x == nil { + return "" + } + pwd := x.Elements().Child("password") + if pwd == nil { + return "" + } + return pwd.Text() +} + +func getOccupantStatusElement(o *mucmodel.Occupant, selfNotifying, + includeUserJID bool) *xmpp.Element { + x := newOccupantAffiliationRoleElement(o, includeUserJID, false) + if selfNotifying { + x.AppendElement(newStatusElement("110")) + } + el := xmpp.NewElementName("presence").AppendElement(x).SetID(uuid.New().String()) + return el +} + +func getOccupantSelfPresenceElement(o *mucmodel.Occupant, nonAnonymous bool, + id string) *xmpp.Element { + x := newOccupantAffiliationRoleElement(o, false, false) + x.AppendElement(newStatusElement("110")) + if nonAnonymous { + x.AppendElement(newStatusElement("100")) + } + return xmpp.NewElementName("presence").AppendElement(x).SetID(id) +} + +func getRoomSubjectElement(subject string) *xmpp.Element { + s := xmpp.NewElementName("subject").SetText(subject) + m := xmpp.NewElementName("message").SetType("groupchat").SetID(uuid.New().String()) + return m.AppendElement(s) +} + +func getAckStanza(from, to *jid.JID) xmpp.Stanza { + item := xmpp.NewElementName("item") + item.SetAttribute("affiliation", "owner").SetAttribute("role", "moderator") + e := xmpp.NewElementNamespace("x", mucNamespaceUser) + e.AppendElement(item) + e.AppendElement(newStatusElement("110")) + e.AppendElement(newStatusElement("210")) + + presence := xmpp.NewElementName("presence").AppendElement(e) + ack, err := xmpp.NewPresenceFromElement(presence, from, to) + if err != nil { + log.Error(err) + return nil + } + return ack +} + +func getFormStanza(iq *xmpp.IQ, form *xep0004.DataForm) xmpp.Stanza { + query := xmpp.NewElementNamespace("query", mucNamespaceOwner) + query.AppendElement(form.Element()) + + e := xmpp.NewElementName("iq").SetID(iq.ID()).SetType("result").AppendElement(query) + stanza, err := xmpp.NewIQFromElement(e, iq.ToJID(), iq.FromJID()) + if err != nil { + log.Error(err) + return nil + } + return stanza +} + +func newStatusElement(code string) *xmpp.Element { + s := xmpp.NewElementName("status") + s.SetAttribute("code", code) + return s +} + +func newOccupantItem(o *mucmodel.Occupant, includeUserJID, includeNick bool) *xmpp.Element { + i := xmpp.NewElementName("item") + a := o.GetAffiliation() + r := o.GetRole() + if a == "" { + a = "none" + } + if r == "" { + r = "none" + } + i.SetAttribute("affiliation", a) + i.SetAttribute("role", r) + if includeUserJID { + i.SetAttribute("jid", o.BareJID.String()) + } + if includeNick { + i.SetAttribute("nick", o.OccupantJID.Resource()) + } + return i +} + +func newOccupantAffiliationRoleElement(o *mucmodel.Occupant, includeUserJID, + includeNick bool) *xmpp.Element { + item := newOccupantItem(o, includeUserJID, includeNick) + e := xmpp.NewElementNamespace("x", mucNamespaceUser) + e.AppendElement(item) + return e +} + +// addResourceToBareJID joins bareJID and resource into a full jid +func addResourceToBareJID(bareJID *jid.JID, resource string) *jid.JID { + res, _ := jid.NewWithString(bareJID.String()+"/"+resource, true) + return res +} diff --git a/module/xep0045/elements_test.go b/module/xep0045/elements_test.go new file mode 100644 index 000000000..aa11b0fee --- /dev/null +++ b/module/xep0045/elements_test.go @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "testing" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_NewStatus(t *testing.T) { + status := newStatusElement("200") + require.Equal(t, status.Name(), "status") + require.Equal(t, status.Attributes().Get("code"), "200") +} + +func TestXEP0045_GetAckStanza(t *testing.T) { + from, _ := jid.New("ortuman", "test.org", "balcony", false) + to, _ := jid.New("ortuman", "example.org", "garden", false) + message := getAckStanza(from, to) + require.Equal(t, message.Name(), "presence") + require.Equal(t, message.From(), from.String()) + require.Equal(t, message.To(), to.String()) + + xel := message.Elements().Child("x") + require.Equal(t, xel.Namespace(), mucNamespaceUser) +} + +func TestXEP0045_GetFormStanza(t *testing.T) { + from, _ := jid.New("ortuman", "test.org", "balcony", false) + to, _ := jid.New("ortuman", "example.org", "garden", false) + r, c := setupTest("jackal.im") + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, c.Room(), c.Occupant()) + + iq := &xmpp.IQ{} + iq.SetFromJID(from) + iq.SetToJID(to) + iq.SetID("create") + + room := &mucmodel.Room{Config: &mucmodel.RoomConfig{}} + form := muc.getRoomConfigForm(context.Background(), room) + require.NotNil(t, form) + require.Len(t, form.Fields, 17) + + formStanza := getFormStanza(iq, form) + require.NotNil(t, formStanza) +} + +func TestXEP0045_InstantRoomCreateIQ(t *testing.T) { + from, _ := jid.New("ortuman", "jackal.im", "balcony", true) + to, _ := jid.New("room", "conference.jackal.im", "", true) + + falseX := xmpp.NewElementNamespace("x", xep0004.FormNamespace).SetAttribute("type", "not_submit") + falseQuery := xmpp.NewElementNamespace("query", mucNamespaceOwner).AppendElement(falseX) + falseIQ := xmpp.NewElementName("iq").SetID("create1").SetType("set").AppendElement(falseQuery) + falseRequest, _ := xmpp.NewIQFromElement(falseIQ, from, to) + require.False(t, isIQForInstantRoomCreate(falseRequest)) + + x := xmpp.NewElementNamespace("x", xep0004.FormNamespace).SetAttribute("type", "submit") + query := xmpp.NewElementNamespace("query", mucNamespaceOwner).AppendElement(x) + iq := xmpp.NewElementName("iq").SetID("create1").SetType("set").AppendElement(query) + request, _ := xmpp.NewIQFromElement(iq, from, to) + require.True(t, isIQForInstantRoomCreate(request)) +} diff --git a/module/xep0045/iq.go b/module/xep0045/iq.go new file mode 100644 index 000000000..4b5ece2d4 --- /dev/null +++ b/module/xep0045/iq.go @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "fmt" + + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +// isIQForRoomDestroy returns true if iq stanza is for destroying a room +func isIQForRoomDestroy(iq *xmpp.IQ) bool { + if !iq.IsSet() { + return false + } + query := iq.Elements().Child("query") + destroy := query.Elements().Child("destroy") + if destroy == nil { + return false + } + return true +} + +// destroyRoom proceses the iq aimed at destroying an existing muc room +func (s *Muc) destroyRoom(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + owner, errStanza := s.getOccupantFromStanza(ctx, room, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + // notify occupants in the room that the room is destroyed + err := s.notifyRoomDestroyed(ctx, owner, room, iq) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, iq.InternalServerError()) + return + } + + s.deleteRoom(ctx, room) + _ = s.router.Route(ctx, iq.ResultIQ()) +} + +func (s *Muc) notifyRoomDestroyed(ctx context.Context, owner *mucmodel.Occupant, + room *mucmodel.Room, iq *xmpp.IQ) error { + // the actor destroying the room is sent to all of the occupants in the item element + owner.SetAffiliation("") + owner.SetRole("") + + // create the stanza to notify the room + itemEl := newOccupantItem(owner, false, false) + destroyEl := iq.Elements().Child("query").Elements().Child("destroy") + xEl := xmpp.NewElementNamespace("x", mucNamespaceUser) + xEl.AppendElement(itemEl).AppendElement(destroyEl) + presenceEl := xmpp.NewElementName("presence").SetType("unavailable").AppendElement(xEl) + + for _, occJID := range room.GetAllOccupantJIDs() { + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + err = s.sendPresenceToOccupant(ctx, o, o.OccupantJID, presenceEl) + if err != nil { + return err + } + } + return nil +} + +// modifyOccupantList handles the iq stanzas sent to the muc admin namespace of type set +func (s *Muc) modifyOccupantList(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + sender, errStanza := s.getOccupantFromStanza(ctx, room, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + query := iq.Elements().Child("query") + items := query.Elements().Children("item") + // one item per occupant whose privilege is being changed + for _, item := range items { + err := s.modifyOccupantPrivilege(ctx, room, sender, item) + if err != nil { + _ = s.router.Route(ctx, iq.BadRequestError()) + return + } + } + + _ = s.router.Route(ctx, iq.ResultIQ()) +} + +// modifyOccupantPrivilege changes occupants role/affiliation as specified in the item element +func (s *Muc) modifyOccupantPrivilege(ctx context.Context, room *mucmodel.Room, + sender *mucmodel.Occupant, item xmpp.XElement) error { + role := item.Attributes().Get("role") + affiliation := item.Attributes().Get("affiliation") + + var err error + switch { + case role != "": + err = s.modifyOccupantRole(ctx, room, sender, item) + case affiliation != "": + err = s.modifyOccupantAffiliation(ctx, room, sender, item) + default: + err = fmt.Errorf("Role and affiliation not specified") + } + return err +} + +func (s *Muc) modifyOccupantRole(ctx context.Context, room *mucmodel.Room, + sender *mucmodel.Occupant, item xmpp.XElement) error { + occ, newRole := s.getOccupantAndNewRole(ctx, room, item) + if occ == nil { + return fmt.Errorf("Occupant not in the room") + } + + if !sender.CanChangeRole(occ, newRole) { + return fmt.Errorf("Sender not allowed to change the role") + } + + reason := getReasonFromItem(item) + if newRole == "none" { + err := s.kickOccupant(ctx, room, occ, sender.OccupantJID.Resource(), reason) + if err != nil { + return err + } + } else { + occ.SetRole(newRole) + s.repOccupant.UpsertOccupant(ctx, occ) + + occEl := getOccupantChangeElement(occ, reason) + err := s.sendPresenceToRoom(ctx, room, occ.OccupantJID, occEl) + if err != nil { + return err + } + } + + return nil +} + +func (s *Muc) getOccupantAndNewRole(ctx context.Context, room *mucmodel.Room, + item xmpp.XElement) (*mucmodel.Occupant, string) { + occNick := item.Attributes().Get("nick") + occJID := addResourceToBareJID(room.RoomJID, occNick) + occ, err := s.repOccupant.FetchOccupant(ctx, occJID) + if err != nil || occ == nil { + return nil, "" + } + newRole := item.Attributes().Get("role") + return occ, newRole +} + +func (s *Muc) kickOccupant(ctx context.Context, room *mucmodel.Room, kickedOcc *mucmodel.Occupant, + actor, reason string) error { + kickedOcc.SetAffiliation("") + kickedOcc.SetRole("") + s.occupantExitsRoom(ctx, room, kickedOcc) + + kickedElSelf := getKickedOccupantElement(actor, reason, true) + err := s.sendPresenceToOccupant(ctx, kickedOcc, kickedOcc.OccupantJID, kickedElSelf) + if err != nil { + return err + } + + kickedElRoom := getKickedOccupantElement(actor, reason, false) + err = s.sendPresenceToRoom(ctx, room, kickedOcc.OccupantJID, kickedElRoom) + if err != nil { + return err + } + + return nil +} + +func (s *Muc) modifyOccupantAffiliation(ctx context.Context, room *mucmodel.Room, + sender *mucmodel.Occupant, item xmpp.XElement) error { + occ, newAffiliation := s.getOccupantAndNewAffiliation(ctx, room, item) + if occ == nil { + return fmt.Errorf("Occupant not in the room") + } + + if !sender.CanChangeAffiliation(occ, newAffiliation) { + return fmt.Errorf("Sender not allowed to change the affiliation") + } + + occ.SetAffiliation(newAffiliation) + room.SetDefaultRole(occ) + s.repOccupant.UpsertOccupant(ctx, occ) + + reason := getReasonFromItem(item) + occEl := getOccupantChangeElement(occ, reason) + err := s.sendPresenceToRoom(ctx, room, occ.OccupantJID, occEl) + if err != nil { + return err + } + + if newAffiliation == "none" || newAffiliation == "outcast" { + err = s.handleUserRemoval(ctx, room, sender, occ, newAffiliation, reason) + if err != nil { + return err + } + } + + return nil +} + +func (s *Muc) getOccupantAndNewAffiliation(ctx context.Context, room *mucmodel.Room, + item xmpp.XElement) (*mucmodel.Occupant, string) { + userBareJIDStr := item.Attributes().Get("jid") + userBareJID, err := jid.NewWithString(userBareJIDStr, true) + if err != nil { + return nil, "" + } + occJID, ok := room.GetOccupantJID(userBareJID) + if !ok { + return nil, "" + } + occ, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil || occ == nil { + return nil, "" + } + newAff := item.Attributes().Get("affiliation") + return occ, newAff +} + +func (s *Muc) handleUserRemoval(ctx context.Context, room *mucmodel.Room, sender, occ *mucmodel.Occupant, + newAffiliation, reason string) error { + if !room.Config.Open && newAffiliation == "none" { + removedEl := getRoomMemberRemovedElement(sender.OccupantJID.Resource(), reason) + err := s.sendPresenceToRoom(ctx, room, occ.OccupantJID, removedEl) + if err != nil { + return err + } + room.OccupantLeft(occ) + s.repOccupant.DeleteOccupant(ctx, occ.OccupantJID) + } else if newAffiliation == "outcast" { + bannedEl := getUserBannedElement(sender.OccupantJID.Resource(), reason) + err := s.sendPresenceToRoom(ctx, room, occ.OccupantJID, bannedEl) + if err != nil { + return err + } + } + return nil +} + +// getOccupantList handles the iq stanzas sent to the muc admin namespace of type get +func (s *Muc) getOccupantList(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + sender, errStanza := s.getOccupantFromStanza(ctx, room, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + // resOccupants is the list of occupants that matches the role/affiliation from iq + resOccupants, errStanza := s.getRequestedOccupants(ctx, room, sender, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + listEl := getOccupantsInfoElement(resOccupants, iq.ID(), + room.Config.OccupantCanDiscoverRealJID(sender)) + iqRes, _ := xmpp.NewIQFromElement(listEl, room.RoomJID, iq.FromJID()) + _ = s.router.Route(ctx, iqRes) +} + +func (s *Muc) getRequestedOccupants(ctx context.Context, room *mucmodel.Room, + sender *mucmodel.Occupant, iq *xmpp.IQ) ([]*mucmodel.Occupant, xmpp.Stanza) { + switch filter := getFilterFromIQ(iq); filter { + case "moderator", "participant", "visitor": + resOccupants, err := s.getOccupantsByRole(ctx, room, sender, filter) + if err != nil { + return nil, iq.NotAllowedError() + } + return resOccupants, nil + case "owner", "admin", "member", "outcast": + resOccupants, err := s.getOccupantsByAffiliation(ctx, room, sender, filter) + if err != nil { + return nil, iq.NotAllowedError() + } + return resOccupants, nil + } + + return nil, iq.BadRequestError() +} + +func getFilterFromIQ(iq *xmpp.IQ) string { + item := iq.Elements().Child("query").Elements().Child("item") + if item == nil { + return "" + } + aff := item.Attributes().Get("affiliation") + if aff != "" { + return aff + } + return item.Attributes().Get("role") +} + +// isIQForInstantRoomCreate returns true if iq stanza is for creating an instant room +func isIQForInstantRoomCreate(iq *xmpp.IQ) bool { + if !iq.IsSet() { + return false + } + query := iq.Elements().Child("query") + x := query.Elements().Child("x") + if x == nil { + return false + } + if x.Namespace() != "jabber:x:data" || x.Type() != "submit" || x.Elements().Count() != 0 { + return false + } + return true +} + +// createInstantRoom unlocks the existing room specified in iq stanza +func (s *Muc) createInstantRoom(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + _, errStanza := s.getOwnerFromIQ(ctx, room, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + room.Locked = false + err := s.repRoom.UpsertRoom(ctx, room) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, iq.InternalServerError()) + } + + _ = s.router.Route(ctx, iq.ResultIQ()) +} + +// isIQForRoomConfigRequest returns true if iq stanza is for retrieving a room configuration form +func isIQForRoomConfigRequest(iq *xmpp.IQ) bool { + if !iq.IsGet() { + return false + } + query := iq.Elements().Child("query") + if query.Elements().Count() != 0 { + return false + } + return true +} + +// sendRoomConfiguration returns the room configuration form to the roow owner who requested it +func (s *Muc) sendRoomConfiguration(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + _, errStanza := s.getOwnerFromIQ(ctx, room, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + configForm := s.getRoomConfigForm(ctx, room) + stanza := getFormStanza(iq, configForm) + _ = s.router.Route(ctx, stanza) +} + +// isIQForRoomConfigSubmission returns true if iq stanza is for submitting a room configuration +func isIQForRoomConfigSubmission(iq *xmpp.IQ) bool { + if !iq.IsSet() { + return false + } + query := iq.Elements().Child("query") + form := query.Elements().Child("x") + if form == nil || form.Namespace() != xep0004.FormNamespace || form.Type() != "submit" { + return false + } + return true +} + +// processRoomConfiguration handles the iq modifying the existing's room config +func (s *Muc) processRoomConfiguration(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + _, errStanza := s.getOwnerFromIQ(ctx, room, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + formEl := iq.Elements().Child("query").Elements().Child("x") + switch formEl.Type() { + case "submit": + errStanza := s.configureRoom(ctx, room, formEl, iq) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + case "cancel": + if room.Locked { + s.deleteRoom(ctx, room) + } + default: + _ = s.router.Route(ctx, iq.BadRequestError()) + return + } + + _ = s.router.Route(ctx, iq.ResultIQ()) +} + +func (s *Muc) configureRoom(ctx context.Context, room *mucmodel.Room, formEl xmpp.XElement, + iq *xmpp.IQ) xmpp.Stanza { + form, err := xep0004.NewFormFromElement(formEl) + if err != nil { + return iq.BadRequestError() + } + + updatedAnonimity, ok := s.updateRoomWithForm(ctx, room, form) + if !ok { + return iq.NotAcceptableError() + } + + updatedRoomEl := getRoomUpdatedElement(room.Config.NonAnonymous, updatedAnonimity) + s.sendMessageToRoom(ctx, room, room.RoomJID, updatedRoomEl) + return nil +} diff --git a/module/xep0045/iq_test.go b/module/xep0045/iq_test.go new file mode 100644 index 000000000..5e1252286 --- /dev/null +++ b/module/xep0045/iq_test.go @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_DestroyRoom(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + + reasonEl := xmpp.NewElementName("reason").SetText("Reason for destroying") + destroyEl := xmpp.NewElementName("destroy").AppendElement(reasonEl) + queryEl := xmpp.NewElementNamespace("query", mucNamespaceOwner).AppendElement(destroyEl) + iqEl := xmpp.NewElementName("iq").SetType("set").SetID("destroy1").AppendElement(queryEl) + iq, _ := xmpp.NewIQFromElement(iqEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.destroyRoom(nil, mock.room, iq) + + ackOcc := mock.occStm.ReceiveElement() + require.Equal(t, ackOcc.From(), mock.occ.OccupantJID.String()) + require.Equal(t, ackOcc.Type(), "unavailable") + reason := ackOcc.Elements().Child("x").Elements().Child("destroy").Elements().Child("reason") + require.Equal(t, reason.Text(), "Reason for destroying") + + ownerAck := mock.ownerStm.ReceiveElement() + require.Equal(t, ownerAck.From(), mock.owner.OccupantJID.String()) + require.Equal(t, ownerAck.Type(), "unavailable") + ownerAck = mock.ownerStm.ReceiveElement() + require.Equal(t, ownerAck.Type(), "result") + + room, err := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.Nil(t, err) + require.Nil(t, room) + owner, err := mock.muc.repOccupant.FetchOccupant(nil, mock.owner.OccupantJID) + require.Nil(t, err) + require.Nil(t, owner) + occ, err := mock.muc.repOccupant.FetchOccupant(nil, mock.occ.OccupantJID) + require.Nil(t, err) + require.Nil(t, occ) +} + +func TestXEP0045_GetOccupantList(t *testing.T) { + mock := setupTestRoomAndOwner() + + itemEl := xmpp.NewElementName("item").SetAttribute("role", "moderator") + queryEl := xmpp.NewElementNamespace("query", mucNamespaceAdmin) + queryEl.AppendElement(itemEl) + iqEl := xmpp.NewElementName("iq").SetID("admin1").SetType("get") + iqEl.AppendElement(queryEl) + iq, _ := xmpp.NewIQFromElement(iqEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.getOccupantList(nil, mock.room, iq) + + resAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resAck.Type(), "result") + query := resAck.Elements().Child("query") + require.NotNil(t, query) + require.Equal(t, query.Namespace(), mucNamespaceAdmin) + item := query.Elements().Child("item") + require.NotNil(t, item) + require.Equal(t, item.Attributes().Get("nick"), mock.owner.OccupantJID.Resource()) +} + +func TestXEP0045_ChangeAffiliation(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + require.False(t, mock.occ.IsAdmin()) + + reasonEl := xmpp.NewElementName("reason").SetText("reason for affiliation change") + itemEl := xmpp.NewElementName("item").SetAttribute("jid", + mock.occ.BareJID.String()) + itemEl.SetAttribute("affiliation", "admin").AppendElement(reasonEl) + queryEl := xmpp.NewElementNamespace("query", mucNamespaceAdmin) + queryEl.AppendElement(itemEl) + iqEl := xmpp.NewElementName("iq").SetID("admin1").SetType("set") + iqEl.AppendElement(queryEl) + iq, _ := xmpp.NewIQFromElement(iqEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.modifyOccupantList(nil, mock.room, iq) + + acAck := mock.occStm.ReceiveElement() + require.Equal(t, acAck.From(), mock.occ.OccupantJID.String()) + resArAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resArAck.From(), mock.occ.OccupantJID.String()) + resAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resAck.Type(), "result") + + resOcc, _ := mock.muc.repOccupant.FetchOccupant(nil, mock.occ.OccupantJID) + require.True(t, resOcc.IsAdmin()) +} + +func TestXEP0045_ChangeRole(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + require.False(t, mock.occ.IsModerator()) + + reasonEl := xmpp.NewElementName("reason").SetText("reason for role change") + itemEl := xmpp.NewElementName("item").SetAttribute("nick", + mock.occ.OccupantJID.Resource()) + itemEl.SetAttribute("role", "moderator").AppendElement(reasonEl) + queryEl := xmpp.NewElementNamespace("query", mucNamespaceAdmin) + queryEl.AppendElement(itemEl) + iqEl := xmpp.NewElementName("iq").SetID("mod1").SetType("set") + iqEl.AppendElement(queryEl) + iq, _ := xmpp.NewIQFromElement(iqEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.modifyOccupantList(nil, mock.room, iq) + + rcAck := mock.occStm.ReceiveElement() + require.Equal(t, rcAck.From(), mock.occ.OccupantJID.String()) + resCrAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resCrAck.From(), mock.occ.OccupantJID.String()) + resAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resAck.Type(), "result") + + resOcc, _ := mock.muc.repOccupant.FetchOccupant(nil, mock.occ.OccupantJID) + require.True(t, resOcc.IsModerator()) +} + +func TestXEP0045_KickOccupant(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + + reasonEl := xmpp.NewElementName("reason").SetText("reason for kicking") + itemEl := xmpp.NewElementName("item").SetAttribute("nick", + mock.occ.OccupantJID.Resource()) + itemEl.SetAttribute("role", "none").AppendElement(reasonEl) + queryEl := xmpp.NewElementNamespace("query", mucNamespaceAdmin) + queryEl.AppendElement(itemEl) + iqEl := xmpp.NewElementName("iq").SetID("kick1").SetType("set") + iqEl.AppendElement(queryEl) + iq, _ := xmpp.NewIQFromElement(iqEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.modifyOccupantList(nil, mock.room, iq) + + kickedAck := mock.occStm.ReceiveElement() + require.Equal(t, kickedAck.Type(), "unavailable") + resKickAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resKickAck.Type(), "unavailable") + resAck := mock.ownerStm.ReceiveElement() + require.Equal(t, resAck.Type(), "result") + + _, found := mock.room.GetOccupantJID(mock.occ.BareJID) + require.False(t, found) + kicked, _ := mock.muc.repOccupant.FetchOccupant(nil, mock.occ.OccupantJID) + require.Nil(t, kicked) +} + +func TestXEP0045_CreateInstantRoom(t *testing.T) { + mock := setupTestRoomAndOwner() + mock.room.Locked = true + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + x := xmpp.NewElementNamespace("x", xep0004.FormNamespace) + x.SetAttribute("type", "submit") + query := xmpp.NewElementNamespace("query", mucNamespaceOwner).AppendElement(x) + iq := xmpp.NewElementName("iq").SetID("create1").SetType("set") + iq.AppendElement(query) + request, _ := xmpp.NewIQFromElement(iq, mock.ownerFullJID, mock.room.RoomJID) + + require.True(t, isIQForInstantRoomCreate(request)) + mock.muc.createInstantRoom(context.Background(), mock.room, request) + + ack := mock.ownerStm.ReceiveElement() + require.Equal(t, ack, request.ResultIQ()) + + updatedRoom, _ := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.False(t, updatedRoom.Locked) +} + +func TestXEP0045_SendRoomConfiguration(t *testing.T) { + mock := setupTestRoomAndOwner() + mock.room.Locked = true + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + query := xmpp.NewElementNamespace("query", mucNamespaceOwner) + iq := xmpp.NewElementName("iq").SetID("create1").SetType("get") + iq.AppendElement(query) + request, _ := xmpp.NewIQFromElement(iq, mock.ownerFullJID, mock.room.RoomJID) + + require.True(t, mock.muc.MatchesIQ(request)) + require.True(t, isIQForRoomConfigRequest(request)) + mock.muc.sendRoomConfiguration(context.Background(), mock.room, request) + + ack := mock.ownerStm.ReceiveElement() + require.Equal(t, ack.From(), mock.room.RoomJID.String()) + require.Equal(t, ack.To(), mock.ownerFullJID.String()) + require.Equal(t, ack.Name(), "iq") + require.Equal(t, ack.Type(), "result") + require.Equal(t, ack.ID(), "create1") + + queryResult := ack.Elements().Child("query") + require.NotNil(t, queryResult) + require.Equal(t, queryResult.Namespace(), mucNamespaceOwner) + + formElement := queryResult.Elements().Child("x") + require.NotNil(t, formElement) + form, err := xep0004.NewFormFromElement(formElement) + require.Nil(t, err) + require.Equal(t, form.Type, xep0004.Form) + require.Equal(t, len(form.Fields), 17) +} + +func TestXEP0045_ProcessRoomConfiguration(t *testing.T) { + mock := setupTestRoomAndOwner() + mock.room.Locked = true + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + require.True(t, mock.room.Locked) + require.NotEqual(t, mock.room.Name, "Configured Room") + require.NotEqual(t, mock.room.Config.MaxOccCnt, 23) + require.False(t, mock.room.Config.Public) + require.False(t, mock.room.Config.NonAnonymous) + + configForm := mock.muc.getRoomConfigForm(context.Background(), mock.room) + require.NotNil(t, configForm) + configForm.Type = xep0004.Submit + for i, field := range configForm.Fields { + switch field.Var { + case ConfigName: + configForm.Fields[i].Values = []string{"Configured Room"} + case ConfigMaxUsers: + configForm.Fields[i].Values = []string{"23"} + case ConfigWhoIs: + configForm.Fields[i].Values = []string{"1"} + case ConfigPublic: + configForm.Fields[i].Values = []string{"0"} + } + } + + query := xmpp.NewElementNamespace("query", mucNamespaceOwner) + query.AppendElement(configForm.Element()) + e := xmpp.NewElementName("iq").SetID("create").SetType("set").AppendElement(query) + stanza, err := xmpp.NewIQFromElement(e, mock.ownerFullJID, mock.room.RoomJID) + require.Nil(t, err) + + require.True(t, isIQForRoomConfigSubmission(stanza)) + mock.muc.processRoomConfiguration(context.Background(), mock.room, stanza) + + ack := mock.ownerStm.ReceiveElement() + assert.EqualValues(t, ack.Type(), "groupchat") + + confRoom, err := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.Nil(t, err) + require.False(t, confRoom.Locked) + require.Equal(t, confRoom.Name, "Configured Room") + require.Equal(t, confRoom.Config.MaxOccCnt, 23) + require.False(t, confRoom.Config.Public) + require.True(t, confRoom.Config.NonAnonymous) +} diff --git a/module/xep0045/message.go b/module/xep0045/message.go new file mode 100644 index 000000000..6c96091a3 --- /dev/null +++ b/module/xep0045/message.go @@ -0,0 +1,390 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "fmt" + "strconv" + + "github.com/google/uuid" + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +// isVoiceRequest returns true if the message stanza is asking for a voice +func isVoiceRequest(message *xmpp.Message) bool { + x := message.Elements().Child("x") + if x == nil || x.Namespace() != "jabber:x:data" || x.Type() != "submit" { + return false + } + return true +} + +// voiceRequest handles the request for voice sent in the message stanza +func (s *Muc) voiceRequest(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + occ, errStanza := s.getOccupantFromStanza(ctx, room, message) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + // asking for voice makes sense only in a moderated room + if !room.Config.Moderated { + _ = s.router.Route(ctx, message.NotAllowedError()) + return + } + + // visitor is asking for a voice, moderator is approving it + switch { + case occ.IsVisitor(): + errStanza = s.askForVoice(ctx, room, occ, message) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + } + case occ.IsModerator(): + errStanza = s.approveVoiceRequest(ctx, room, message) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + } + } +} + +// askForVoice sends the voice request form to the visitor who requested it +func (s *Muc) askForVoice(ctx context.Context, room *mucmodel.Room, visitor *mucmodel.Occupant, + message *xmpp.Message) xmpp.Stanza { + formEl := message.Elements().Child("x") + form, err := xep0004.NewFormFromElement(formEl) + if err != nil { + return message.BadRequestError() + } + + if form.Fields.ValueForFieldOfType("muc#role", xep0004.ListSingle) != "participant" { + return message.NotAllowedError() + } + + approvalForm := s.getVoiceRequestForm(ctx, visitor) + msg := xmpp.NewElementName("message").SetID(uuid.New().String()) + msg.AppendElement(approvalForm.Element()) + for _, occJID := range room.GetAllOccupantJIDs() { + o, _ := s.repOccupant.FetchOccupant(ctx, &occJID) + if o.IsModerator() { + s.sendMessageToOccupant(ctx, o, room.RoomJID, msg) + } + } + + return nil +} + +// approveVoiceRequest processes the moderator's form submission of the voice request +func (s *Muc) approveVoiceRequest(ctx context.Context, room *mucmodel.Room, + message *xmpp.Message) xmpp.Stanza { + formEl := message.Elements().Child("x") + form, err := xep0004.NewFormFromElement(formEl) + if err != nil { + return message.BadRequestError() + } + + occ, err := s.processVoiceApprovalForm(ctx, room, form) + if err != nil { + return message.BadRequestError() + } + if occ != nil { + presenceEl := getOccupantChangeElement(occ, "") + s.sendPresenceToRoom(ctx, room, occ.OccupantJID, presenceEl) + } + + return nil +} + +func (s *Muc) processVoiceApprovalForm(ctx context.Context, room *mucmodel.Room, + form *xep0004.DataForm) (*mucmodel.Occupant, error) { + requestAllow := false + var role, userJIDStr, nick string + for _, field := range form.Fields { + if len(field.Values) == 0 { + continue + } + switch field.Var { + case "muc#role": + role = field.Values[0] + case "muc#jid": + userJIDStr = field.Values[0] + case "muc#roomnick": + nick = field.Values[0] + case "muc#request_allow": + requestAllow, _ = strconv.ParseBool(field.Values[0]) + } + } + + if requestAllow { + occ, err := s.approveVoice(ctx, room, userJIDStr, role, nick) + if err != nil { + return nil, err + } + return occ, nil + } + + return nil, nil +} + +func (s *Muc) approveVoice(ctx context.Context, room *mucmodel.Room, userJIDStr, role, + nick string) (*mucmodel.Occupant, error) { + userJID, err := jid.NewWithString(userJIDStr, false) + if err != nil { + return nil, err + } + occJID, ok := room.GetOccupantJID(userJID.ToBareJID()) + if !ok { + return nil, fmt.Errorf("User not in the room") + } + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return nil, err + } + if role != "participant" || nick != o.OccupantJID.Resource() { + return nil, fmt.Errorf("Form not filled out correctly") + } + + o.SetRole("participant") + err = s.repOccupant.UpsertOccupant(ctx, o) + if err != nil { + return nil, err + } + + return o, nil +} + +// changeSubject handles the message stanza that is changing the room subject +func (s *Muc) changeSubject(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + occ, errStanza := s.getOccupantFromStanza(ctx, room, message) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + newSubject := message.Elements().Child("subject").Text() + if room.Config.OccupantCanChangeSubject(occ) { + room.Subject = newSubject + s.repRoom.UpsertRoom(ctx, room) + } else { + _ = s.router.Route(ctx, message.ForbiddenError()) + return + } + + subjectEl := xmpp.NewElementName("subject").SetText(newSubject) + msgEl := xmpp.NewElementName("message").SetType("groupchat").SetID(uuid.New().String()) + msgEl.AppendElement(subjectEl) + + s.sendMessageToRoom(ctx, room, occ.OccupantJID, msgEl) +} + +// isDeclineInvitation returns true if the message stanza is declining invitation to a room +func isDeclineInvitation(message *xmpp.Message) bool { + x := message.Elements().Child("x") + if x == nil || x.Namespace() != mucNamespaceUser { + return false + } + decline := x.Elements().Child("decline") + if decline == nil { + return false + } + return true +} + +// declineInvitation handles the message stanza declining invite to a room +func (s *Muc) declineInvitation(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + if !room.UserIsInvited(message.FromJID().ToBareJID()) { + _ = s.router.Route(ctx, message.ForbiddenError()) + return + } + + room.DeleteInvite(message.FromJID().ToBareJID()) + s.repRoom.UpsertRoom(ctx, room) + + msg := getDeclineStanza(room, message) + _ = s.router.Route(ctx, msg) +} + +// isInvite returns true if message is a room invite mediated by a room +func isInvite(message *xmpp.Message) bool { + x := message.Elements().Child("x") + if x == nil || x.Namespace() != mucNamespaceUser { + return false + } + invite := x.Elements().Child("invite") + if invite == nil { + return false + } + return true +} + +// inviteUser handles the message stanza inviting a user to the room +func (s *Muc) inviteUser(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + if errStanza := s.userHasVoice(ctx, room, message.FromJID(), message); errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + occ, errStanza := s.getOccupantFromStanza(ctx, room, message) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + if !room.Config.AllowInvites || (!room.Config.Open && !occ.IsModerator()) { + _ = s.router.Route(ctx, message.ForbiddenError()) + return + } + + // add to the list of invited users + invJID := getInvitedUserJID(message) + err := room.InviteUser(invJID) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, message.InternalServerError()) + } + s.repRoom.UpsertRoom(ctx, room) + + s.forwardInviteToUser(ctx, room, message) +} + +func (s *Muc) forwardInviteToUser(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + inviteFrom := message.FromJID() + inviteTo := getInvitedUserJID(message) + + msg := getInvitationStanza(room, inviteFrom, inviteTo, message) + _ = s.router.Route(ctx, msg) +} + +// sendPm handles the message stanza which is of the type "chat" +func (s *Muc) sendPM(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + // private message should be addressed to a particular occupant, not the whole room + if !message.ToJID().IsFull() { + _ = s.router.Route(ctx, message.BadRequestError()) + return + } + + // check if user is allowed to send the pm + if errStanza := s.userCanPMOccupant(ctx, room, message.FromJID(), message.ToJID(), message); errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + // send the PM + senderJID, ok := room.GetOccupantJID(message.FromJID().ToBareJID()) + if !ok { + _ = s.router.Route(ctx, message.ForbiddenError()) + return + } + + msgBody := message.Elements().Child("body") + if msgBody == nil { + _ = s.router.Route(ctx, message.BadRequestError()) + return + } + + s.messageOccupant(ctx, message.ToJID(), &senderJID, msgBody, message.ID(), true) +} + +// userCanPMOccupant returns true if user has permission to send the pm, error stanza otherwise +func (s *Muc) userCanPMOccupant(ctx context.Context, room *mucmodel.Room, usrJID, occJID *jid.JID, message *xmpp.Message) xmpp.Stanza { + // check if user can send private messages in this room + usrOccJID, ok := room.GetOccupantJID(usrJID.ToBareJID()) + if !ok { + return message.NotAcceptableError() + } + + usrOcc, err := s.repOccupant.FetchOccupant(ctx, &usrOccJID) + if err != nil || usrOcc == nil { + return message.InternalServerError() + } + + if !room.Config.OccupantCanSendPM(usrOcc) { + return message.NotAcceptableError() + } + + // check if the target occupant exists + occ, err := s.repOccupant.FetchOccupant(ctx, occJID) + if err != nil || occ == nil { + return message.ItemNotFoundError() + } + + // make sure the target occupant is in the same room + if occJID.ToBareJID().String() != room.RoomJID.String() { + return message.NotAcceptableError() + } + + return nil +} + +// messageEveryone handles the message stanza of the type "groupchat" +func (s *Muc) messageEveryone(ctx context.Context, room *mucmodel.Room, message *xmpp.Message) { + // the groupmessage should be addressed to the whole room, not a particular occupant + if message.ToJID().IsFull() { + _ = s.router.Route(ctx, message.BadRequestError()) + return + } + + // check if user is allowed to send a groupchat message + if errStanza := s.userHasVoice(ctx, room, message.FromJID(), message); errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + sendersOccupantJID, ok := room.GetOccupantJID(message.FromJID().ToBareJID()) + if !ok { + _ = s.router.Route(ctx, message.ForbiddenError()) + return + } + + msgBody := message.Elements().Child("body") + if msgBody == nil { + _ = s.router.Route(ctx, message.BadRequestError()) + return + } + + for _, occJID := range room.GetAllOccupantJIDs() { + s.messageOccupant(ctx, &occJID, &sendersOccupantJID, msgBody, message.ID(), false) + } +} + +// userHasVoice returns null if user is allowed to speak in the room, error stanza otherwise +func (s *Muc) userHasVoice(ctx context.Context, room *mucmodel.Room, userJID *jid.JID, + message *xmpp.Message) xmpp.Stanza { + // user has to be occupant of the room + occJID, ok := room.GetOccupantJID(userJID.ToBareJID()) + if !ok { + return message.NotAcceptableError() + } + + occ, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + log.Error(err) + return message.InternalServerError() + } + + if room.Config.Moderated && occ.IsVisitor() { + return message.ForbiddenError() + } + + return nil +} + +func (s *Muc) messageOccupant(ctx context.Context, occJID, senderJID *jid.JID, + body xmpp.XElement, id string, private bool) { + occupant, err := s.repOccupant.FetchOccupant(ctx, occJID) + if err != nil { + log.Error(err) + return + } + + msgEl := getMessageElement(body, id, private) + _ = s.sendMessageToOccupant(ctx, occupant, senderJID, msgEl) +} diff --git a/module/xep0045/message_test.go b/module/xep0045/message_test.go new file mode 100644 index 000000000..0ab579825 --- /dev/null +++ b/module/xep0045/message_test.go @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/pborman/uuid" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_VoiceRequestAndApproval(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + mock.occ.SetRole("visitor") + mock.muc.repOccupant.UpsertOccupant(nil, mock.occ) + mock.room.Config.Moderated = true + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + requestForm := &xep0004.DataForm{ + Type: xep0004.Submit, + } + requestForm.Fields = append(requestForm.Fields, xep0004.Field{ + Type: xep0004.ListSingle, + Var: "muc#role", + Label: "Requested role", + Values: []string{"participant"}, + }) + msgEl := xmpp.NewElementName("message").AppendElement(requestForm.Element()) + msg, _ := xmpp.NewMessageFromElement(msgEl, mock.occFullJID, mock.room.RoomJID) + + mock.muc.voiceRequest(nil, mock.room, msg) + + approvalMessage := mock.ownerStm.ReceiveElement() + require.Equal(t, approvalMessage.From(), mock.room.RoomJID.String()) + formEl := approvalMessage.Elements().Child("x") + require.NotNil(t, formEl) + approvalForm, err := xep0004.NewFormFromElement(formEl) + require.Nil(t, err) + require.Equal(t, approvalForm.Type, xep0004.Form) + approvalForm.Type = xep0004.Submit + for i, field := range approvalForm.Fields { + if field.Var == "muc#request_allow" { + approvalForm.Fields[i].Values = []string{"true"} + } + } + apMsgEl := xmpp.NewElementName("message").AppendElement(approvalForm.Element()) + apMsg, _ := xmpp.NewMessageFromElement(apMsgEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.voiceRequest(nil, mock.room, apMsg) + + ackOcc := mock.occStm.ReceiveElement() + require.Equal(t, ackOcc.From(), mock.occ.OccupantJID.String()) + itemEl := ackOcc.Elements().Child("x").Elements().Child("item") + require.NotNil(t, itemEl) + require.Equal(t, itemEl.Attributes().Get("role"), "participant") + + occ, _ := mock.muc.repOccupant.FetchOccupant(nil, mock.occ.OccupantJID) + require.True(t, occ.IsParticipant()) +} + +func TestXEP0045_ChangeSubject(t *testing.T) { + mock := setupTestRoomAndOwner() + + subjectEl := xmpp.NewElementName("subject").SetText("new subject") + msgEl := xmpp.NewElementName("message").SetType("groupchat") + msgEl.AppendElement(subjectEl) + msg, _ := xmpp.NewMessageFromElement(msgEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.changeSubject(nil, mock.room, msg) + + ack := mock.ownerStm.ReceiveElement() + require.Equal(t, ack.Type(), "groupchat") + newSubject := ack.Elements().Child("subject") + require.NotNil(t, newSubject) + require.Equal(t, newSubject.Text(), "new subject") + + updatedRoom, _ := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.Equal(t, updatedRoom.Subject, "new subject") +} + +func TestXEP0045_DeclineInvite(t *testing.T) { + mock := setupTestRoomAndOwner() + invitedUserJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + mock.room.InviteUser(invitedUserJID.ToBareJID()) + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + // user declines the invitation + reason := xmpp.NewElementName("reason").SetText("Sorry, not for me!") + invite := xmpp.NewElementName("decline") + invite.SetAttribute("to", mock.owner.BareJID.String()).AppendElement(reason) + x := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(invite) + m := xmpp.NewElementName("message").SetID("id-decline").AppendElement(x) + msg, _ := xmpp.NewMessageFromElement(m, invitedUserJID, mock.room.RoomJID) + + require.True(t, isDeclineInvitation(msg)) + mock.muc.declineInvitation(nil, mock.room, msg) + + decline := mock.ownerStm.ReceiveElement() + require.Equal(t, decline.From(), mock.room.RoomJID.String()) + room, _ := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.False(t, room.UserIsInvited(invitedUserJID.ToBareJID())) +} + +func TestXEP0045_SendInvite(t *testing.T) { + mock := setupTestRoomAndOwner() + mock.room.Config.AllowInvites = true + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + invitedUserJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + invStm := stream.NewMockC2S("id-2", invitedUserJID) + invStm.SetPresence(xmpp.NewPresence(invitedUserJID.ToBareJID(), invitedUserJID, + xmpp.AvailableType)) + mock.muc.router.Bind(context.Background(), invStm) + + // user is not already invited + require.False(t, mock.room.UserIsInvited(invitedUserJID.ToBareJID())) + + // owner sends the invitation + reason := xmpp.NewElementName("reason").SetText("Join me!") + invite := xmpp.NewElementName("invite") + invite.SetAttribute("to", invitedUserJID.ToBareJID().String()) + invite.AppendElement(reason) + x := xmpp.NewElementNamespace("x", mucNamespaceUser).AppendElement(invite) + m := xmpp.NewElementName("message").SetID("id-invite").AppendElement(x) + msg, err := xmpp.NewMessageFromElement(m, mock.ownerFullJID, mock.room.RoomJID) + require.Nil(t, err) + + require.True(t, isInvite(msg)) + mock.muc.inviteUser(context.Background(), mock.room, msg) + + inviteStanza := invStm.ReceiveElement() + require.Equal(t, inviteStanza.From(), mock.room.RoomJID.String()) + + updatedRoom, _ := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.True(t, updatedRoom.UserIsInvited(invitedUserJID.ToBareJID())) +} + +func TestXEP0045_MessageEveryone(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + + // owner sends the group message + body := xmpp.NewElementName("body").SetText("Hello world!") + msgEl := xmpp.NewMessageType(uuid.New(), "groupchat").AppendElement(body) + msg, _ := xmpp.NewMessageFromElement(msgEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.messageEveryone(context.Background(), mock.room, msg) + + regMsg := mock.occStm.ReceiveElement() + ownerMsg := mock.ownerStm.ReceiveElement() + + require.Equal(t, regMsg.Type(), "groupchat") + msgTxt := regMsg.Elements().Child("body").Text() + require.Equal(t, msgTxt, "Hello world!") + + require.Equal(t, ownerMsg.Type(), "groupchat") + msgTxt = ownerMsg.Elements().Child("body").Text() + require.Equal(t, msgTxt, "Hello world!") +} + +func TestXEP0045_SendPM(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + mock.room.Config.SetWhoCanSendPM("all") + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + // owner sends the private message + body := xmpp.NewElementName("body").SetText("Hello ortuman!") + msgEl := xmpp.NewMessageType(uuid.New(), "chat").AppendElement(body) + m, _ := xmpp.NewMessageFromElement(msgEl, mock.ownerFullJID, mock.occ.OccupantJID) + + mock.muc.sendPM(context.Background(), mock.room, m) + + regMsg := mock.occStm.ReceiveElement() + require.Equal(t, regMsg.Type(), "chat") + msgTxt := regMsg.Elements().Child("body").Text() + require.Equal(t, msgTxt, "Hello ortuman!") +} diff --git a/module/xep0045/muc.go b/module/xep0045/muc.go new file mode 100644 index 000000000..aced1f11f --- /dev/null +++ b/module/xep0045/muc.go @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "sync" + + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +// Muc represents the multi user chat service as defined in xep-0045 +type Muc struct { + cfg *Config + disco *xep0030.DiscoInfo + repRoom repository.Room + repOccupant repository.Occupant + allRooms []jid.JID + router router.Router + runQueue *runqueue.RunQueue + mu sync.RWMutex +} + +func New(cfg *Config, disco *xep0030.DiscoInfo, router router.Router, repRoom repository.Room, + repOccupant repository.Occupant) *Muc { + // muc service needs a separate hostname + if len(cfg.MucHost) == 0 || router.Hosts().IsLocalHost(cfg.MucHost) { + log.Errorf("Muc service could not be started - invalid hostname") + return nil + } + s := &Muc{ + cfg: cfg, + disco: disco, + repRoom: repRoom, + repOccupant: repOccupant, + router: router, + runQueue: runqueue.New("muc"), + } + // add the muc service hostname to the hosts + router.Hosts().AddMucHostname(cfg.MucHost) + if disco != nil { + setupDiscoService(cfg, disco, s) + } + return s +} + +// MatchesIQ is accepting all IQs aimed at the conference service +func (s *Muc) MatchesIQ(iq *xmpp.IQ) bool { + return s.router.Hosts().IsConferenceHost(iq.ToJID().Domain()) +} + +// ProcessIQ queues the iq stanzas directed to the conference service +func (s *Muc) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { + s.runQueue.Run(func() { + s.processIQ(ctx, iq) + }) +} + +func (s *Muc) processIQ(ctx context.Context, iq *xmpp.IQ) { + roomJID := iq.ToJID().ToBareJID() + room, err := s.repRoom.FetchRoom(ctx, roomJID) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, iq.InternalServerError()) + return + } + if room == nil { + _ = s.router.Route(ctx, iq.ItemNotFoundError()) + return + } + + query := iq.Elements().Child("query") + if query == nil { + _ = s.router.Route(ctx, iq.BadRequestError()) + return + } + iqDomain := query.Namespace() + + switch iqDomain { + case mucNamespaceOwner: + s.processIQOwner(ctx, room, iq) + case mucNamespaceAdmin: + s.processIQAdmin(ctx, room, iq) + default: + _ = s.router.Route(ctx, iq.BadRequestError()) + } +} + +func (s *Muc) processIQOwner(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + switch { + case isIQForInstantRoomCreate(iq): + s.createInstantRoom(ctx, room, iq) + case isIQForRoomConfigRequest(iq): + s.sendRoomConfiguration(ctx, room, iq) + case isIQForRoomConfigSubmission(iq): + s.processRoomConfiguration(ctx, room, iq) + case isIQForRoomDestroy(iq): + s.destroyRoom(ctx, room, iq) + default: + _ = s.router.Route(ctx, iq.BadRequestError()) + } +} + +func (s *Muc) processIQAdmin(ctx context.Context, room *mucmodel.Room, iq *xmpp.IQ) { + switch { + case iq.IsGet(): + s.getOccupantList(ctx, room, iq) + case iq.IsSet(): + s.modifyOccupantList(ctx, room, iq) + default: + _ = s.router.Route(ctx, iq.BadRequestError()) + } +} + +// ProcessPresence queues the presence stanzas directed to the conference service +func (s *Muc) ProcessPresence(ctx context.Context, presence *xmpp.Presence) { + s.runQueue.Run(func() { + s.processPresence(ctx, presence) + }) +} + +func (s *Muc) processPresence(ctx context.Context, presence *xmpp.Presence) { + roomJID := presence.ToJID().ToBareJID() + room, err := s.repRoom.FetchRoom(ctx, roomJID) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, presence.InternalServerError()) + return + } + + switch { + case isPresenceToEnterRoom(presence): + s.enterRoom(ctx, room, presence) + case isChangingStatus(presence): + s.changeStatus(ctx, room, presence) + case presence.IsUnavailable(): + s.exitRoom(ctx, room, presence) + default: + s.changeNickname(ctx, room, presence) + } +} + +// ProcessMessage queues the message stanzas directed to the conference service +func (s *Muc) ProcessMessage(ctx context.Context, message *xmpp.Message) { + s.runQueue.Run(func() { + s.processMessage(ctx, message) + }) +} + +func (s *Muc) processMessage(ctx context.Context, message *xmpp.Message) { + roomJID := message.ToJID().ToBareJID() + room, err := s.repRoom.FetchRoom(ctx, roomJID) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, message.InternalServerError()) + return + } + if room == nil { + _ = s.router.Route(ctx, message.ItemNotFoundError()) + return + } + + switch { + case isInvite(message): + s.inviteUser(ctx, room, message) + case isDeclineInvitation(message): + s.declineInvitation(ctx, room, message) + case isVoiceRequest(message): + s.voiceRequest(ctx, room, message) + case message.IsGroupChat() && message.Elements().Child("subject") != nil: + s.changeSubject(ctx, room, message) + case message.IsGroupChat(): + s.messageEveryone(ctx, room, message) + case message.IsChat() || message.Type() == "": + s.sendPM(ctx, room, message) + default: + _ = s.router.Route(ctx, message.BadRequestError()) + } +} + +func (s *Muc) GetMucHostname() string { + return s.cfg.MucHost +} + +// GetDefaultRoomConfig returns the room configuration as specified in the server's yaml +func (s *Muc) GetDefaultRoomConfig() *mucmodel.RoomConfig { + conf := s.cfg.RoomDefaults + return &conf +} + +func (s *Muc) Shutdown() error { + c := make(chan struct{}) + s.runQueue.Stop(func() { close(c) }) + <-c + return nil +} diff --git a/module/xep0045/muc_test.go b/module/xep0045/muc_test.go new file mode 100644 index 000000000..a117fbc61 --- /dev/null +++ b/module/xep0045/muc_test.go @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/pborman/uuid" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_NewService(t *testing.T) { + r, c := setupTest("jackal.im") + + failedMuc := New(&Config{MucHost: "jackal.im"}, nil, r, c.Room(), c.Occupant()) + require.Nil(t, failedMuc) + + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, c.Room(), c.Occupant()) + defer func() { _ = muc.Shutdown() }() + + require.False(t, muc.router.Hosts().IsConferenceHost("jackal.im")) + require.True(t, muc.router.Hosts().IsConferenceHost("conference.jackal.im")) + + require.Equal(t, muc.GetMucHostname(), "conference.jackal.im") +} + +func TestXEP0045_ProcessIQInstantRoom(t *testing.T) { + mock := setupMockMucService() + userJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + occupantJID, _ := jid.New("room", "conference.jackal.im", "nick", true) + stm := stream.NewMockC2S(uuid.New(), userJID) + stm.SetPresence(xmpp.NewPresence(userJID.ToBareJID(), userJID, xmpp.AvailableType)) + mock.muc.router.Bind(nil, stm) + + // creating a locked room + err := mock.muc.newRoom(nil, userJID, occupantJID) + require.Nil(t, err) + room, err := mock.muc.repRoom.FetchRoom(nil, occupantJID.ToBareJID()) + require.Nil(t, err) + require.NotNil(t, room) + // NOTE(mmalesev) uncomment once this is changed in the room create function + //require.True(t, room.Locked) + + // instant room create iq + x := xmpp.NewElementNamespace("x", xep0004.FormNamespace) + x.SetAttribute("type", "submit") + query := xmpp.NewElementNamespace("query", mucNamespaceOwner).AppendElement(x) + iq := xmpp.NewElementName("iq").SetID("create1").SetType("set") + iq.AppendElement(query) + request, err := xmpp.NewIQFromElement(iq, userJID, occupantJID) + require.Nil(t, err) + + // sending an instant room request into the stream + require.True(t, mock.muc.MatchesIQ(request)) + mock.muc.ProcessIQ(context.Background(), request) + + // receive the instant room creation confirmation + ack := stm.ReceiveElement() + require.Equal(t, ack, request.ResultIQ()) + + // the room should be unlocked now + updatedRoom, err := mock.muc.repRoom.FetchRoom(nil, occupantJID.ToBareJID()) + require.False(t, updatedRoom.Locked) +} + +func TestXEP0045_ProcessPresenceNewRoom(t *testing.T) { + mock := setupMockMucService() + + from, _ := jid.New("ortuman", "jackal.im", "balcony", true) + to, _ := jid.New("room", "conference.jackal.im", "nick", true) + + stm := stream.NewMockC2S(uuid.New(), from) + stm.SetPresence(xmpp.NewPresence(from.ToBareJID(), from, xmpp.AvailableType)) + mock.muc.router.Bind(context.Background(), stm) + + e := xmpp.NewElementNamespace("x", mucNamespace) + p := xmpp.NewElementName("presence").AppendElement(e) + presence, _ := xmpp.NewPresenceFromElement(p, from, to) + + mock.muc.ProcessPresence(context.Background(), presence) + + // sender receives the appropriate response + ack := stm.ReceiveElement() + require.Equal(t, ack.String(), getAckStanza(to, from).String()) + + // the room is created + roomMem, err := mock.muc.repRoom.FetchRoom(nil, to.ToBareJID()) + require.Nil(t, err) + require.NotNil(t, roomMem) + require.Equal(t, to.ToBareJID().String(), roomMem.RoomJID.String()) + require.Equal(t, mock.muc.allRooms[0].String(), to.ToBareJID().String()) + oMem, err := mock.muc.repOccupant.FetchOccupant(nil, to) + require.Nil(t, err) + require.NotNil(t, oMem) + require.Equal(t, from.ToBareJID().String(), oMem.BareJID.String()) + + // the room is locked + // NOTE(mmalesev) uncomment once this is changed in the room create function + //require.True(t, roomMem.Locked) +} + +func TestXEP0045_ProcessMessageMsgEveryone(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + + // owner sends the group message + body := xmpp.NewElementName("body").SetText("Hello world!") + msgEl := xmpp.NewMessageType(uuid.New(), "groupchat").AppendElement(body) + msg, _ := xmpp.NewMessageFromElement(msgEl, mock.ownerFullJID, mock.room.RoomJID) + + mock.muc.ProcessMessage(context.Background(), msg) + + regMsg := mock.occStm.ReceiveElement() + ownerMsg := mock.ownerStm.ReceiveElement() + + require.Equal(t, regMsg.Type(), "groupchat") + msgTxt := regMsg.Elements().Child("body").Text() + require.Equal(t, msgTxt, "Hello world!") + + require.Equal(t, ownerMsg.Type(), "groupchat") + msgTxt = ownerMsg.Elements().Child("body").Text() + require.Equal(t, msgTxt, "Hello world!") +} diff --git a/module/xep0045/occupant.go b/module/xep0045/occupant.go new file mode 100644 index 000000000..c28faaf61 --- /dev/null +++ b/module/xep0045/occupant.go @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "fmt" + + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +const ( + // instructions in the voice request form + voiceRequestInstructions = "To approve this voice request select the checkbox and click OK." +) + +// newOccupant takes the full user JID and occupant JID and creates a room occupant +func (s *Muc) newOccupant(ctx context.Context, userJID, occJID *jid.JID) (*mucmodel.Occupant, error) { + // check if the occupant already exists + o, err := s.repOccupant.FetchOccupant(ctx, occJID) + switch { + case err != nil: + return nil, err + case o == nil: + // if the occupant with this occJID does not exist, create it + o, err = mucmodel.NewOccupant(occJID, userJID.ToBareJID()) + if err != nil { + return nil, err + } + case userJID.ToBareJID().String() != o.BareJID.String(): + // user with the given userJID is trying to use occJID of another user + return nil, fmt.Errorf("xep0045: Can't use another user's occupant nick") + } + + if !userJID.IsFull() { + return nil, fmt.Errorf("xep0045: User jid has to specify the resource") + + } + o.AddResource(userJID.Resource()) + + err = s.repOccupant.UpsertOccupant(ctx, o) + if err != nil { + return nil, err + } + + return o, nil +} + +// createOwner returns a room occupant with the owner affiliation +func (s *Muc) createOwner(ctx context.Context, userJID, occJID *jid.JID) (*mucmodel.Occupant, error) { + o, err := s.newOccupant(ctx, userJID, occJID) + if err != nil { + return nil, err + } + o.SetAffiliation("owner") + err = s.repOccupant.UpsertOccupant(ctx, o) + if err != nil { + return nil, err + } + return o, nil +} + +// getOccupantFromStanza takes xmpp stanza and returns the room occupant associated with the sender +func (s *Muc) getOccupantFromStanza(ctx context.Context, room *mucmodel.Room, + stanza xmpp.Stanza) (*mucmodel.Occupant, xmpp.Stanza) { + occJID, ok := room.GetOccupantJID(stanza.FromJID().ToBareJID()) + if !ok { + return nil, xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrForbidden, nil) + } + + occ, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + log.Error(err) + return nil, xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrInternalServerError, nil) + } + return occ, nil +} + +// getOwnerFromIQ takes iq stanza and returns occupant instance if the sender is an owner +func (s *Muc) getOwnerFromIQ(ctx context.Context, room *mucmodel.Room, + iq *xmpp.IQ) (*mucmodel.Occupant, xmpp.Stanza) { + occ, errStanza := s.getOccupantFromStanza(ctx, room, iq) + if errStanza != nil { + return nil, errStanza + } + + if !occ.IsOwner() { + return nil, iq.ForbiddenError() + } + + return occ, nil +} + +// getOccupantsByRole returns a slice of the occupants with the given role in the room +func (s *Muc) getOccupantsByRole(ctx context.Context, room *mucmodel.Room, + sender *mucmodel.Occupant, role string) ([]*mucmodel.Occupant, error) { + if !sender.IsModerator() { + return nil, fmt.Errorf("xep0045: only mods can retrive the list of %ss", role) + } + res := make([]*mucmodel.Occupant, 0) + for _, occJID := range room.GetAllOccupantJIDs() { + o, _ := s.repOccupant.FetchOccupant(ctx, &occJID) + if o.GetRole() == role { + res = append(res, o) + } + } + return res, nil +} + +// getOccupantsByRole returns a slice of the occupants with the given affiliation in the room +func (s *Muc) getOccupantsByAffiliation(ctx context.Context, room *mucmodel.Room, + sender *mucmodel.Occupant, aff string) ([]*mucmodel.Occupant, error) { + switch aff { + case "outcast", "member": + if !sender.IsAdmin() && !sender.IsOwner() { + return nil, fmt.Errorf("xep0045: only admins and owners can get %ss", aff) + } + case "owner", "admin": + if !sender.IsOwner() { + return nil, fmt.Errorf("xep0045: only owners can retrive the %ss", aff) + } + default: + return nil, fmt.Errorf("xep0045: unknown affiliation") + } + + res := make([]*mucmodel.Occupant, 0) + for _, occJID := range room.GetAllOccupantJIDs() { + o, _ := s.repOccupant.FetchOccupant(ctx, &occJID) + if o.GetAffiliation() == aff { + res = append(res, o) + } + } + return res, nil +} + +// sendPresenceToOccupant sends the given presence element to every resource the occupant uses +func (s *Muc) sendPresenceToOccupant(ctx context.Context, o *mucmodel.Occupant, + from *jid.JID, presenceEl *xmpp.Element) error { + for _, resource := range o.GetAllResources() { + to := addResourceToBareJID(o.BareJID, resource) + p, err := xmpp.NewPresenceFromElement(presenceEl, from, to) + if err != nil { + return err + } + err = s.router.Route(ctx, p) + if err != nil { + return err + } + } + return nil +} + +// sendPresenceToOccupant sends the given message element to every resource the occupant uses +func (s *Muc) sendMessageToOccupant(ctx context.Context, o *mucmodel.Occupant, + from *jid.JID, messageEl *xmpp.Element) error { + for _, resource := range o.GetAllResources() { + to := addResourceToBareJID(o.BareJID, resource) + message, err := xmpp.NewMessageFromElement(messageEl, from, to) + if err != nil { + return err + } + err = s.router.Route(ctx, message) + if err != nil { + return err + } + } + return nil +} + +func (s *Muc) getVoiceRequestForm(ctx context.Context, o *mucmodel.Occupant) *xep0004.DataForm { + form := &xep0004.DataForm{ + Type: xep0004.Form, + Title: "Voice request", + Instructions: voiceRequestInstructions, + } + form.Fields = append(form.Fields, xep0004.Field{ + Var: xep0004.FormType, + Type: xep0004.Hidden, + Values: []string{"http://jabber.org/protocol/muc#request"}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: "muc#role", + Type: xep0004.ListSingle, + Label: "Requested role", + Values: []string{"participant"}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: "muc#jid", + Type: xep0004.JidSingle, + Label: "User ID", + Values: []string{o.BareJID.String()}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: "muc#roomnick", + Type: xep0004.TextSingle, + Label: "Room nickname", + Values: []string{o.OccupantJID.Resource()}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: "muc#request_allow", + Type: xep0004.Boolean, + Label: "Grant voice to this person?", + Values: []string{"false"}, + }) + return form +} diff --git a/module/xep0045/occupant_test.go b/module/xep0045/occupant_test.go new file mode 100644 index 000000000..998063d19 --- /dev/null +++ b/module/xep0045/occupant_test.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "testing" + + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_CreateOwner(t *testing.T) { + r, c := setupTest("jackal.im") + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, c.Room(), c.Occupant()) + defer func() { _ = muc.Shutdown() }() + + occJID, _ := jid.New("room", "conference.jackal.im", "nick", true) + fullJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + o, err := muc.createOwner(nil, fullJID, occJID) + require.Nil(t, err) + + oMem, err := muc.repOccupant.FetchOccupant(nil, occJID) + require.Nil(t, err) + require.NotNil(t, oMem) + assert.EqualValues(t, o, oMem) +} + +func TestXEP0045_CreateOccupant(t *testing.T) { + r, c := setupTest("jackal.im") + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, c.Room(), c.Occupant()) + defer func() { _ = muc.Shutdown() }() + + occJID, _ := jid.New("room", "conference.jackal.im", "nick", true) + fullJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + o, err := muc.newOccupant(nil, fullJID, occJID) + require.Nil(t, err) + + oMem, err := muc.repOccupant.FetchOccupant(nil, occJID) + require.Nil(t, err) + require.NotNil(t, oMem) + assert.EqualValues(t, o, oMem) + + errUsr, _ := jid.New("milos", "jackal.im", "laptop", true) + errOcc, err := muc.newOccupant(nil, errUsr, occJID) + require.NotNil(t, err) + require.Nil(t, errOcc) +} diff --git a/module/xep0045/presence.go b/module/xep0045/presence.go new file mode 100644 index 000000000..4a4ece2d0 --- /dev/null +++ b/module/xep0045/presence.go @@ -0,0 +1,409 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "strings" + + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +// exitRoom handles the presence stanza of the type unavailable +func (s *Muc) exitRoom(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) { + o, errStanza := s.getOccupantFromStanza(ctx, room, presence) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + // check if the user is trying to make himself unavailable, or someone else + if o.OccupantJID.String() != presence.ToJID().String() { + _ = s.router.Route(ctx, presence.ForbiddenError()) + return + } + + s.occupantExitsRoom(ctx, room, o) + + err := s.sendOccExitedRoom(ctx, o, room) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, presence.InternalServerError()) + } +} + +func (s *Muc) occupantExitsRoom(ctx context.Context, room *mucmodel.Room, o *mucmodel.Occupant) { + // if the user has no affiliation then its occupant JID is not reserved + if o.HasNoAffiliation() { + s.repOccupant.DeleteOccupant(ctx, o.OccupantJID) + } else { + o.SetRole("") + s.repOccupant.UpsertOccupant(ctx, o) + } + + room.OccupantLeft(o) + s.repRoom.UpsertRoom(ctx, room) + + if !room.Config.Persistent && room.IsEmpty() { + s.deleteRoom(ctx, room) + } +} + +func (s *Muc) sendOccExitedRoom(ctx context.Context, occExiting *mucmodel.Occupant, + room *mucmodel.Room) error { + resultPresence := xmpp.NewElementName("presence").SetType("unavailable") + + for _, occJID := range room.GetAllOccupantJIDs() { + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + xEl := newOccupantAffiliationRoleElement(occExiting, + room.Config.OccupantCanDiscoverRealJID(o), false) + if occJID.String() == occExiting.OccupantJID.String() { + xEl.AppendElement(newStatusElement("110")) + } + resultPresence.AppendElement(xEl) + err = s.sendPresenceToOccupant(ctx, o, occExiting.OccupantJID, resultPresence) + } + return nil +} + +// isChangingStatus returns true if the presence is updating the occupant's status +func isChangingStatus(presence *xmpp.Presence) bool { + status := presence.Elements().Child("status") + show := presence.Elements().Child("show") + if status == nil && show == nil { + return false + } + return true +} + +func (s *Muc) changeStatus(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) { + o, errStanza := s.getOccupantFromStanza(ctx, room, presence) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + // user can only change his own status + if o.OccupantJID.String() != presence.ToJID().String() { + _ = s.router.Route(ctx, presence.ForbiddenError()) + return + } + + if o.IsVisitor() { + _ = s.router.Route(ctx, presence.ForbiddenError()) + return + } + + show := presence.Elements().Child("show") + status := presence.Elements().Child("status") + err := s.sendStatus(ctx, room, o, show, status) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, presence.InternalServerError()) + return + } +} + +func (s *Muc) sendStatus(ctx context.Context, room *mucmodel.Room, sender *mucmodel.Occupant, + show, status xmpp.XElement) error { + presence := xmpp.NewElementName("presence").AppendElement(show).AppendElement(status) + + for _, occJID := range room.GetAllOccupantJIDs() { + if occJID.String() == sender.OccupantJID.String() { + continue + } + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + xEl := newOccupantAffiliationRoleElement(sender, + room.Config.OccupantCanDiscoverRealJID(o), false) + presence.AppendElement(xEl) + err = s.sendPresenceToOccupant(ctx, o, sender.OccupantJID, presence) + if err != nil { + return err + } + } + + return nil +} + +// changeNickname processes presence that is changing the sender's nickname in the room +func (s *Muc) changeNickname(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) { + if errStanza := s.newNickIsAvailable(ctx, presence); errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + + occ, errStanza := s.getOccupantFromStanza(ctx, room, presence) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return + } + oldOccJID := occ.OccupantJID + + occ.SetAffiliation("") + room.OccupantLeft(occ) + s.repOccupant.DeleteOccupant(ctx, oldOccJID) + + occ.OccupantJID = presence.ToJID() + room.AddOccupant(occ) + s.repOccupant.UpsertOccupant(ctx, occ) + s.repRoom.UpsertRoom(ctx, room) + + // send the unavailable and presence stanzas to the room members + err := s.sendNickChangeAck(ctx, room, occ, oldOccJID) + if err != nil { + log.Error(err) + _ = s.router.Route(ctx, presence.InternalServerError()) + return + } +} + +func (s *Muc) sendNickChangeAck(ctx context.Context, room *mucmodel.Room, + newOcc *mucmodel.Occupant, oldJID *jid.JID) error { + for _, occJID := range room.GetAllOccupantJIDs() { + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + selfNotifying := (occJID.String() == newOcc.OccupantJID.String()) + getRealJID := room.Config.OccupantCanDiscoverRealJID(o) + + unavailableEl := getOccupantUnavailableElement(newOcc, selfNotifying, getRealJID) + err = s.sendPresenceToOccupant(ctx, o, oldJID, unavailableEl) + if err != nil { + return err + } + + statusEl := getOccupantStatusElement(newOcc, selfNotifying, getRealJID) + err = s.sendPresenceToOccupant(ctx, o, newOcc.OccupantJID, statusEl) + if err != nil { + return err + } + } + return nil +} + +func (s *Muc) newNickIsAvailable(ctx context.Context, presence *xmpp.Presence) xmpp.Stanza { + o, err := s.repOccupant.FetchOccupant(ctx, presence.ToJID()) + if err != nil { + log.Error(err) + return presence.InternalServerError() + } + if o != nil { + return presence.ConflictError() + } + return nil +} + +// isPresenceToEnterRoom returns true if presence is for entering a room (or creating a new one) +func isPresenceToEnterRoom(presence *xmpp.Presence) bool { + if presence.Type() != "" { + return false + } + x := presence.Elements().ChildNamespace("x", mucNamespace) + if x == nil || len(strings.TrimSpace(x.Text())) != 0 || x.Elements().Count() != 0 { + return false + } + return true +} + +// enterRoom puts the sender into an existing room or makes him an owner of a new room +func (s *Muc) enterRoom(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) { + if room == nil { + err := s.newRoomRequest(ctx, room, presence) + if err != nil { + _ = s.router.Route(ctx, presence.InternalServerError()) + return + } + } else { + err := s.joinExistingRoom(ctx, room, presence) + if err != nil { + _ = s.router.Route(ctx, presence.InternalServerError()) + return + } + } +} + +func (s *Muc) newRoomRequest(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) error { + err := s.newRoom(ctx, presence.FromJID(), presence.ToJID()) + if err != nil { + return err + } + + el := getAckStanza(presence.ToJID(), presence.FromJID()) + _ = s.router.Route(ctx, el) + return nil +} + +func (s *Muc) joinExistingRoom(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) error { + ok, err := s.occupantCanEnterRoom(ctx, room, presence) + if !ok || err != nil { + return err + } + + occ, err := s.newOccupant(ctx, presence.FromJID(), presence.ToJID()) + if err != nil { + return err + } + + err = s.AddOccupantToRoom(ctx, room, occ) + if err != nil { + return err + } + + err = s.sendEnterRoomAck(ctx, room, presence) + if err != nil { + return err + } + + return nil +} + +func (s *Muc) occupantCanEnterRoom(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) (bool, error) { + userJID := presence.FromJID() + occupantJID := presence.ToJID() + + occupant, err := s.repOccupant.FetchOccupant(ctx, occupantJID) + if err != nil { + return false, err + } + + // no one can enter a locked room + if room.Locked { + _ = s.router.Route(ctx, presence.ItemNotFoundError()) + return false, nil + } + + // nick for the occupant has to be provided + if !occupantJID.IsFull() { + _ = s.router.Route(ctx, presence.JidMalformedError()) + return false, nil + } + + errStanza := checkNicknameConflict(room, occupant, userJID, occupantJID, presence) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return false, nil + } + + errStanza = checkPassword(room, presence) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return false, nil + } + + errStanza = checkOccupantMembership(room, occupant, userJID, presence) + if errStanza != nil { + _ = s.router.Route(ctx, errStanza) + return false, nil + } + + // check if this occupant is banned + if occupant != nil && occupant.IsOutcast() { + _ = s.router.Route(ctx, presence.ForbiddenError()) + return false, nil + } + + // check if the maximum number of occupants is reached + if occupant != nil && !occupant.IsOwner() && !occupant.IsAdmin() && room.IsFull() { + _ = s.router.Route(ctx, presence.ServiceUnavailableError()) + return false, nil + } + + return true, nil +} + +func checkNicknameConflict(room *mucmodel.Room, newOccupant *mucmodel.Occupant, + userJID, occupantJID *jid.JID, presence *xmpp.Presence) xmpp.Stanza { + // check if the user, who is already in the room, is entering with a different nickname + oJID, ok := room.GetOccupantJID(userJID.ToBareJID()) + if ok && oJID.String() != occupantJID.String() { + return presence.NotAcceptableError() + } + + // check if another user is trying to use an already occupied nickname + if newOccupant != nil && newOccupant.BareJID.String() != userJID.ToBareJID().String() { + return presence.ConflictError() + } + + return nil +} + +func checkPassword(room *mucmodel.Room, presence *xmpp.Presence) xmpp.Stanza { + // if password required, make sure that it is correctly supplied + if room.Config.PwdProtected { + pwd := getPasswordFromPresence(presence) + if pwd != room.Config.Password { + return presence.NotAuthorizedError() + } + } + return nil +} + +func checkOccupantMembership(room *mucmodel.Room, occupant *mucmodel.Occupant, userJID *jid.JID, + presence *xmpp.Presence) xmpp.Stanza { + // if members-only room, check that the occupant is a member + if !room.Config.Open { + if room.UserIsInvited(userJID.ToBareJID()) { + return nil + } + if occupant != nil && !occupant.HasNoAffiliation() { + return nil + } + return presence.RegistrationRequiredError() + } + return nil +} + +func (s *Muc) sendEnterRoomAck(ctx context.Context, room *mucmodel.Room, presence *xmpp.Presence) error { + newOccupant, err := s.repOccupant.FetchOccupant(ctx, presence.ToJID()) + if err != nil { + return err + } + + for _, occJID := range room.GetAllOccupantJIDs() { + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + + // skip the user entering the room + if o.BareJID.String() == newOccupant.BareJID.String() { + continue + } + + s.sendPresenceAboutNewOccupant(ctx, room, newOccupant, o) + } + + // final notification to the new occupant with status codes (self-presence) + spEl := getOccupantSelfPresenceElement(newOccupant, room.Config.NonAnonymous, presence.ID()) + s.sendPresenceToOccupant(ctx, newOccupant, newOccupant.OccupantJID, spEl) + + // send the room subject + subjEl := getRoomSubjectElement(room.Subject) + s.sendMessageToOccupant(ctx, newOccupant, room.RoomJID, subjEl) + + return nil +} + +func (s *Muc) sendPresenceAboutNewOccupant(ctx context.Context, room *mucmodel.Room, + newOccupant, o *mucmodel.Occupant) { + // notify the new occupant of the existing occupant + oStatusEl := getOccupantStatusElement(o, false, room.Config.OccupantCanDiscoverRealJID(newOccupant)) + s.sendPresenceToOccupant(ctx, newOccupant, o.OccupantJID, oStatusEl) + + // notify the existing occupant of the new occupant + newStatusEl := getOccupantStatusElement(newOccupant, false, room.Config.OccupantCanDiscoverRealJID(o)) + s.sendPresenceToOccupant(ctx, o, newOccupant.OccupantJID, newStatusEl) +} diff --git a/module/xep0045/presence_test.go b/module/xep0045/presence_test.go new file mode 100644 index 000000000..0d5f4b4ef --- /dev/null +++ b/module/xep0045/presence_test.go @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_ExitRoom(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + + // presence for exiting the room + p := xmpp.NewElementName("presence").SetType("unavailable") + status := xmpp.NewElementName("status").SetText("bye!") + p.AppendElement(status) + presence, _ := xmpp.NewPresenceFromElement(p, mock.occFullJID, mock.occ.OccupantJID) + + mock.muc.exitRoom(nil, mock.room, presence) + + ack := mock.ownerStm.ReceiveElement() + require.Equal(t, ack.Type(), "unavailable") + + exists, err := mock.muc.repOccupant.OccupantExists(nil, mock.occ.OccupantJID) + require.Nil(t, err) + require.False(t, exists) + + room, _ := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + require.NotNil(t, room) + require.False(t, room.UserIsInRoom(mock.occ.BareJID)) +} + +func TestXEP0045_ChangeStatus(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + + // presence to change the nick + p := xmpp.NewElementName("presence") + show := xmpp.NewElementName("show").SetText("xa") + p.AppendElement(show) + status := xmpp.NewElementName("status").SetText("my new status") + p.AppendElement(status) + presence, _ := xmpp.NewPresenceFromElement(p, mock.ownerFullJID, mock.owner.OccupantJID) + require.True(t, isChangingStatus(presence)) + + mock.muc.changeStatus(nil, mock.room, presence) + + // the user receives status update + statusStanza := mock.occStm.ReceiveElement() + require.NotNil(t, statusStanza) + require.Equal(t, statusStanza.Elements().Child("status").Text(), "my new status") + require.NotNil(t, statusStanza.Elements().Child("show")) +} + +func TestXEP0045_ChangeNickname(t *testing.T) { + mock := setupTestRoomAndOwnerAndOcc() + newOccJID, _ := jid.New("room", "conference.jackal.im", "newnick", true) + + // presence to change the nick + p := xmpp.NewElementName("presence") + presence, _ := xmpp.NewPresenceFromElement(p, mock.ownerFullJID, newOccJID) + require.NotNil(t, presence) + + mock.muc.changeNickname(nil, mock.room, presence) + + // the user receives unavailable stanza + ackUnavailable := mock.occStm.ReceiveElement() + require.NotNil(t, ackUnavailable) + require.Equal(t, ackUnavailable.Type(), "unavailable") + + // the user receives presence stanza + ackPresence := mock.occStm.ReceiveElement() + require.NotNil(t, ackPresence) + require.Equal(t, ackPresence.From(), newOccJID.String()) + + // old nick is deleted + occBefore, err := mock.muc.repOccupant.FetchOccupant(nil, mock.owner.OccupantJID) + require.Nil(t, err) + require.Nil(t, occBefore) + + // new nick is added + jidAfter, _ := mock.room.GetOccupantJID(mock.owner.BareJID) + require.NotNil(t, jidAfter) + require.Equal(t, jidAfter.String(), newOccJID.String()) + occAfter, err := mock.muc.repOccupant.FetchOccupant(nil, newOccJID) + require.Nil(t, err) + require.NotNil(t, occAfter) + require.Equal(t, occAfter.BareJID.String(), mock.owner.BareJID.String()) +} + +func TestXEP0045_JoinExistingRoom(t *testing.T) { + mock := setupTestRoomAndOwner() + mock.room.Config.PwdProtected = true + mock.room.Config.Password = "secret" + mock.room.Config.Open = false + mock.room.Subject = "Room for testing" + + from, _ := jid.New("ortuman", "jackal.im", "balcony", true) + to, _ := jid.New("room", "conference.jackal.im", "nick", true) + mock.room.InviteUser(from.ToBareJID()) + mock.muc.repRoom.UpsertRoom(nil, mock.room) + + newStm := stream.NewMockC2S(uuid.New(), from) + newStm.SetPresence(xmpp.NewPresence(from.ToBareJID(), from, xmpp.AvailableType)) + mock.muc.router.Bind(context.Background(), newStm) + + pwd := xmpp.NewElementName("password").SetText("secret") + e := xmpp.NewElementNamespace("x", mucNamespace).AppendElement(pwd) + p := xmpp.NewElementName("presence").AppendElement(e) + presence, _ := xmpp.NewPresenceFromElement(p, from, to) + + mock.muc.enterRoom(context.Background(), mock.room, presence) + + // sender receives the appropriate response + ack := newStm.ReceiveElement() + require.Equal(t, ack.From(), mock.owner.OccupantJID.String()) + + // owner receives the appropriate response + ownerAck := mock.ownerStm.ReceiveElement() + require.Equal(t, ownerAck.From(), to.String()) + + // sender receives the self-presence + ackSelf := newStm.ReceiveElement() + require.Equal(t, ackSelf.From(), to.String()) + + // sender receives the room subject + ackSubj := newStm.ReceiveElement() + require.NotNil(t, ackSubj.Elements().Child("subject").Text(), "Room for testing") + + // user is in the room + occ, err := mock.muc.repOccupant.FetchOccupant(context.Background(), to) + require.Nil(t, err) + require.NotNil(t, occ) +} + +func TestXEP0045_NewRoomRequest(t *testing.T) { + mock := setupMockMucService() + from, _ := jid.New("ortuman", "jackal.im", "balcony", true) + to, _ := jid.New("room", "conference.jackal.im", "nick", true) + + stm := stream.NewMockC2S(uuid.New(), from) + stm.SetPresence(xmpp.NewPresence(from.ToBareJID(), from, xmpp.AvailableType)) + mock.muc.router.Bind(context.Background(), stm) + + e := xmpp.NewElementNamespace("x", mucNamespace) + p := xmpp.NewElementName("presence").AppendElement(e) + presence, _ := xmpp.NewPresenceFromElement(p, from, to) + + mock.muc.enterRoom(context.Background(), nil, presence) + + // sender receives the appropriate response + ack := stm.ReceiveElement() + require.Equal(t, ack.String(), getAckStanza(to, from).String()) + + // the room is created + roomMem, err := mock.muc.repRoom.FetchRoom(nil, to.ToBareJID()) + require.Nil(t, err) + require.NotNil(t, roomMem) + require.Equal(t, to.ToBareJID().String(), roomMem.RoomJID.String()) + require.Equal(t, mock.muc.allRooms[0].String(), to.ToBareJID().String()) + oMem, err := mock.muc.repOccupant.FetchOccupant(nil, to) + require.Nil(t, err) + require.NotNil(t, oMem) + require.Equal(t, from.ToBareJID().String(), oMem.BareJID.String()) + //make sure the room is locked + // NOTE(mmalesev) uncomment once this is changed in the room create function + //require.True(t, roomMem.Locked) +} + +func TestXEP0045_OccupantCanEnterRoom(t *testing.T) { + mock := setupTestRoomAndOwner() + + // presence stanza for entering the room correctly + pwd := xmpp.NewElementName("password").SetText("secret") + e := xmpp.NewElementNamespace("x", mucNamespace).AppendElement(pwd) + p := xmpp.NewElementName("presence").AppendElement(e) + presence, _ := xmpp.NewPresenceFromElement(p, mock.ownerFullJID, mock.owner.OccupantJID) + + // owner can enter + canEnter, err := mock.muc.occupantCanEnterRoom(context.Background(), mock.room, presence) + require.Nil(t, err) + require.True(t, canEnter) + + // lock the room, no one should be able to enter now + mock.room.Locked = true + canEnter, err = mock.muc.occupantCanEnterRoom(context.Background(), mock.room, presence) + require.Nil(t, err) + require.False(t, canEnter) + ack := mock.ownerStm.ReceiveElement() + assert.EqualValues(t, ack, presence.ItemNotFoundError()) + room, _ := mock.muc.repRoom.FetchRoom(nil, mock.room.RoomJID) + room.Locked = false +} diff --git a/module/xep0045/room.go b/module/xep0045/room.go new file mode 100644 index 000000000..ce20edbd1 --- /dev/null +++ b/module/xep0045/room.go @@ -0,0 +1,391 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "context" + "strconv" + + "github.com/ortuman/jackal/log" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +const ( + // configuration instructions for a locked room + initialRoomConfigInstructions = ` +Your room has been created! +To accept the default configuration, click OK. To +select a different configuration, please complete +this form. +` + + // configuration instructions for an unlocked room + roomConfigInstructions = "Complete this form to modify the configuration of your room" + + // fields that can be configured in a room configuration form + ConfigName = "muc#roomconfig_roomname" + ConfigDesc = "muc#roomconfig_roomdesc" + ConfigAllowPM = "muc#roomconfig_allowpm" + ConfigAllowInvites = "muc#roomconfig_allowinvites" + ConfigChangeSubj = "muc#roomconfig_changesubject" + ConfigMemberList = "muc#roomconfig_getmemberlist" + ConfigLanguage = "muc#roomconfig_lang" + ConfigMaxUsers = "muc#roomconfig_maxusers" + ConfigMembersOnly = "muc#roomconfig_membersonly" + ConfigModerated = "muc#roomconfig_moderatedroom" + ConfigPwdProtected = "muc#roomconfig_passwordprotectedroom" + ConfigPersistent = "muc#roomconfig_persistentroom" + ConfigPublic = "muc#roomconfig_publicroom" + ConfigPwd = "muc#roomconfig_roomsecret" + ConfigWhoIs = "muc#roomconfig_whois" +) + +// newRoom saves a new room into storage with given owner information +func (s *Muc) newRoom(ctx context.Context, ownerFullJID, ownerOccJID *jid.JID) error { + owner, err := s.createOwner(ctx, ownerFullJID, ownerOccJID) + if err != nil { + return err + } + + roomJID := ownerOccJID.ToBareJID() + _, err = s.createRoom(ctx, roomJID, owner) + if err != nil { + return err + } + + s.mu.Lock() + s.allRooms = append(s.allRooms, *roomJID) + s.mu.Unlock() + + return nil +} + +func (s *Muc) createRoom(ctx context.Context, roomJID *jid.JID, owner *mucmodel.Occupant) (*mucmodel.Room, error) { + r := &mucmodel.Room{ + Config: s.GetDefaultRoomConfig(), + Name: roomJID.Node(), + RoomJID: roomJID, + // NOTE(mmalesev) Locked should be true here, however some clients (e.g. profanity) have issues unlocking, so + // this is a temporary hack + Locked: false, + } + + err := s.AddOccupantToRoom(ctx, r, owner) + if err != nil { + return nil, err + } + return r, nil +} + +// AddOccupantToRoom updates the room in storage with a given occupant +func (s *Muc) AddOccupantToRoom(ctx context.Context, room *mucmodel.Room, occupant *mucmodel.Occupant) error { + room.AddOccupant(occupant) + + err := s.repOccupant.UpsertOccupant(ctx, occupant) + if err != nil { + return err + } + + return s.repRoom.UpsertRoom(ctx, room) +} + +// getRoomConfigForm returns the configuration form for the room +func (s *Muc) getRoomConfigForm(ctx context.Context, room *mucmodel.Room) *xep0004.DataForm { + form := &xep0004.DataForm{ + Type: xep0004.Form, + Title: "Configuration for " + room.Name + "Room", + Instructions: getRoomConfigInstructions(room), + } + form.Fields = append(form.Fields, xep0004.Field{ + Var: xep0004.FormType, + Type: xep0004.Hidden, + Values: []string{"http://jabber.org/protocol/muc#roomconfig"}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigName, + Type: xep0004.TextSingle, + Label: "Natural-Language Room Name", + Values: []string{room.Name}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigDesc, + Type: xep0004.TextSingle, + Label: "Short description of Room", + Values: []string{room.Desc}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigLanguage, + Type: xep0004.TextSingle, + Label: "Natural Language for Room Discussion", + Values: []string{room.Language}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigChangeSubj, + Type: xep0004.Boolean, + Label: "Allow Occupants to Change Subject?", + Values: []string{boolToStr(room.Config.AllowSubjChange)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigAllowInvites, + Type: xep0004.Boolean, + Label: "Allow Occupants to Invite Others?", + Values: []string{boolToStr(room.Config.AllowInvites)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigMembersOnly, + Type: xep0004.Boolean, + Label: "Make Room Members Only?", + Values: []string{boolToStr(!room.Config.Open)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigModerated, + Type: xep0004.Boolean, + Label: "Make Room Moderated?", + Values: []string{boolToStr(room.Config.Moderated)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigPersistent, + Type: xep0004.Boolean, + Label: "Make Room Persistent?", + Values: []string{boolToStr(room.Config.Persistent)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigPublic, + Type: xep0004.Boolean, + Label: "Make Room Publicly Searchable?", + Values: []string{boolToStr(room.Config.Public)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigWhoIs, + Type: xep0004.Boolean, + Label: "Make room NonAnonymous? (show real JIDs)", + Values: []string{boolToStr(room.Config.NonAnonymous)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigPwdProtected, + Type: xep0004.Boolean, + Label: "Password Required to Enter?", + Values: []string{boolToStr(room.Config.PwdProtected)}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Type: xep0004.Fixed, + Values: []string{"If the password is required to enter the room, specify it below"}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigPwd, + Type: xep0004.TextSingle, + Label: "Password", + Values: []string{room.Config.Password}, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigAllowPM, + Type: xep0004.ListSingle, + Label: "Roles that May Send Private Messages", + Values: []string{room.Config.WhoCanSendPM()}, + Options: []xep0004.Option{ + xep0004.Option{Label: "Anyone", Value: mucmodel.All}, + xep0004.Option{Label: "Moderators Only", Value: mucmodel.Moderators}, + xep0004.Option{Label: "Nobody", Value: mucmodel.None}, + }, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigMemberList, + Type: xep0004.ListSingle, + Label: "Who Can Retrieve Member List", + Values: []string{room.Config.WhoCanGetMemberList()}, + Options: []xep0004.Option{ + xep0004.Option{Label: "Anyone", Value: mucmodel.All}, + xep0004.Option{Label: "Moderators Only", Value: mucmodel.Moderators}, + xep0004.Option{Label: "Nobody", Value: mucmodel.None}, + }, + }) + form.Fields = append(form.Fields, xep0004.Field{ + Var: ConfigMaxUsers, + Type: xep0004.ListSingle, + Label: "Maximum Number of Occupants (-1 for unlimited)", + Values: []string{strconv.Itoa(room.Config.MaxOccCnt)}, + Options: []xep0004.Option{ + xep0004.Option{Label: "10", Value: "10"}, + xep0004.Option{Label: "20", Value: "20"}, + xep0004.Option{Label: "30", Value: "30"}, + xep0004.Option{Label: "50", Value: "50"}, + xep0004.Option{Label: "100", Value: "100"}, + xep0004.Option{Label: "500", Value: "100"}, + xep0004.Option{Label: "-1", Value: "-1"}, + }, + }) + return form +} + +func getRoomConfigInstructions(room *mucmodel.Room) (instr string) { + if room.Locked { + instr = initialRoomConfigInstructions + } else { + instr = roomConfigInstructions + } + return +} + +// updateRoomWithForm updates the room information with the information submitted in the form +func (s *Muc) updateRoomWithForm(ctx context.Context, room *mucmodel.Room, form *xep0004.DataForm) (updatedAnonimity, ok bool) { + ok = true + for _, field := range form.Fields { + if len(field.Values) == 0 { + continue + } + switch field.Var { + case ConfigName: + room.Name = field.Values[0] + case ConfigDesc: + room.Desc = field.Values[0] + case ConfigLanguage: + room.Language = field.Values[0] + case ConfigChangeSubj: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.AllowSubjChange = n + case ConfigAllowInvites: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.AllowInvites = n + case ConfigMembersOnly: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.Open = !n + case ConfigModerated: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.Moderated = n + case ConfigPersistent: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.Persistent = n + case ConfigPublic: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.Public = n + case ConfigPwdProtected: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.PwdProtected = n + case ConfigPwd: + room.Config.Password = field.Values[0] + case ConfigAllowPM: + err := room.Config.SetWhoCanSendPM(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + case ConfigMemberList: + err := room.Config.SetWhoCanGetMemberList(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + case ConfigWhoIs: + n, err := strconv.ParseBool(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + if room.Config.NonAnonymous != n { + updatedAnonimity = true + } + room.Config.NonAnonymous = n + case ConfigMaxUsers: + n, err := strconv.Atoi(field.Values[0]) + if err != nil { + log.Error(err) + ok = false + } + room.Config.MaxOccCnt = n + } + } + + // the password has to be specified if it is required to enter the room + if room.Config.PwdProtected && room.Config.Password == "" { + ok = false + } + + // if the configForm was valid, update the room + if ok { + room.Locked = false + s.repRoom.UpsertRoom(ctx, room) + } + + return +} + +func boolToStr(value bool) string { + if value { + return "1" + } + return "0" +} + +// sendPresenceToRoom sends the presence element to every occupant in the room +func (s *Muc) sendPresenceToRoom(ctx context.Context, r *mucmodel.Room, from *jid.JID, + presenceEl *xmpp.Element) error { + for _, occJID := range r.GetAllOccupantJIDs() { + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + err = s.sendPresenceToOccupant(ctx, o, from, presenceEl) + if err != nil { + return err + } + } + return nil +} + +// sendPresenceToRoom sends the message element to every occupant in the room +func (s *Muc) sendMessageToRoom(ctx context.Context, r *mucmodel.Room, from *jid.JID, + messageEl *xmpp.Element) error { + for _, occJID := range r.GetAllOccupantJIDs() { + o, err := s.repOccupant.FetchOccupant(ctx, &occJID) + if err != nil { + return err + } + err = s.sendMessageToOccupant(ctx, o, from, messageEl) + if err != nil { + return err + } + } + return nil +} + +// deleteRoom deletes all occupants in the room, and then the room itself +func (s *Muc) deleteRoom(ctx context.Context, r *mucmodel.Room) { + for _, occJID := range r.GetAllOccupantJIDs() { + s.repOccupant.DeleteOccupant(ctx, &occJID) + } + s.repRoom.DeleteRoom(ctx, r.RoomJID) +} diff --git a/module/xep0045/room_test.go b/module/xep0045/room_test.go new file mode 100644 index 000000000..5ffe460e7 --- /dev/null +++ b/module/xep0045/room_test.go @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0045 + +import ( + "testing" + + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestXEP0045_CreateRoom(t *testing.T) { + r, c := setupTest("jackal.im") + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, c.Room(), c.Occupant()) + defer func() { _ = muc.Shutdown() }() + + occJID, _ := jid.New("room", "conference.jackal.im", "nick", true) + fullJID, _ := jid.New("ortuman", "jackal.im", "balcony", true) + o, err := muc.createOwner(nil, fullJID, occJID) + require.Nil(t, err) + + roomJID, _ := jid.New("room", "conference.jackal.im", "", true) + room, err := muc.createRoom(nil, roomJID, o) + require.Nil(t, err) + require.NotNil(t, room) + require.True(t, room.UserIsInRoom(fullJID.ToBareJID())) + jidInRoom, _ := room.GetOccupantJID(fullJID.ToBareJID()) + assert.EqualValues(t, jidInRoom, *occJID) + + roomMem, err := c.Room().FetchRoom(nil, roomJID) + require.Nil(t, err) + require.Equal(t, roomJID.String(), roomMem.RoomJID.String()) +} + +func TestXEP0045_NewRoom(t *testing.T) { + r, c := setupTest("jackal.im") + muc := New(&Config{MucHost: "conference.jackal.im"}, nil, r, c.Room(), c.Occupant()) + defer func() { _ = muc.Shutdown() }() + + from, _ := jid.New("ortuman", "jackal.im", "balcony", true) + to, _ := jid.New("room", "conference.jackal.im", "nick", true) + err := muc.newRoom(nil, from, to) + require.Nil(t, err) + + roomMem, err := c.Room().FetchRoom(nil, to.ToBareJID()) + require.Nil(t, err) + require.NotNil(t, roomMem) + assert.EqualValues(t, to.ToBareJID(), roomMem.RoomJID) + toRoom, _ := roomMem.GetOccupantJID(from.ToBareJID()) + assert.EqualValues(t, *to, toRoom) + require.Equal(t, muc.allRooms[0].String(), to.ToBareJID().String()) + + oMem, err := c.Occupant().FetchOccupant(nil, to) + require.Nil(t, err) + require.NotNil(t, oMem) + assert.EqualValues(t, from.ToBareJID(), oMem.BareJID) +} diff --git a/module/xep0049/private.go b/module/xep0049/private.go index 50b8cddf7..6ec64e445 100644 --- a/module/xep0049/private.go +++ b/module/xep0049/private.go @@ -6,12 +6,13 @@ package xep0049 import ( + "context" "strings" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" ) @@ -19,30 +20,30 @@ const privateNamespace = "jabber:iq:private" // Private represents a private storage server stream module. type Private struct { - router *router.Router + router router.Router runQueue *runqueue.RunQueue + rep repository.Private } // New returns a private storage IQ handler module. -func New(router *router.Router) *Private { +func New(router router.Router, privRep repository.Private) *Private { x := &Private{ router: router, runQueue: runqueue.New("xep0049"), + rep: privRep, } return x } -// MatchesIQ returns whether or not an IQ should be -// processed by the private storage module. +// MatchesIQ returns whether or not an IQ should be processed by the private storage module. func (x *Private) MatchesIQ(iq *xmpp.IQ) bool { return iq.Elements().ChildNamespace("query", privateNamespace) != nil } -// ProcessIQ processes a private storage IQ -// taking according actions over the associated stream -func (x *Private) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a private storage IQ taking according actions over the associated stream. +func (x *Private) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - x.processIQ(iq) + x.processIQ(ctx, iq) }) } @@ -54,28 +55,28 @@ func (x *Private) Shutdown() error { return nil } -func (x *Private) processIQ(iq *xmpp.IQ) { +func (x *Private) processIQ(ctx context.Context, iq *xmpp.IQ) { q := iq.Elements().ChildNamespace("query", privateNamespace) fromJid := iq.FromJID() toJid := iq.ToJID() validTo := toJid.IsServer() || toJid.Node() == fromJid.Node() if !validTo { - _ = x.router.Route(iq.ForbiddenError()) + _ = x.router.Route(ctx, iq.ForbiddenError()) return } if iq.IsGet() { - x.getPrivate(iq, q) + x.getPrivate(ctx, iq, q) } else if iq.IsSet() { - x.setPrivate(iq, q) + x.setPrivate(ctx, iq, q) } else { - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) return } } -func (x *Private) getPrivate(iq *xmpp.IQ, q xmpp.XElement) { +func (x *Private) getPrivate(ctx context.Context, iq *xmpp.IQ, q xmpp.XElement) { if q.Elements().Count() != 1 { - _ = x.router.Route(iq.NotAcceptableError()) + _ = x.router.Route(ctx, iq.NotAcceptableError()) return } privElem := q.Elements().All()[0] @@ -83,16 +84,16 @@ func (x *Private) getPrivate(iq *xmpp.IQ, q xmpp.XElement) { isValidNS := x.isValidNamespace(privNS) if privElem.Elements().Count() > 0 || !isValidNS { - _ = x.router.Route(iq.NotAcceptableError()) + _ = x.router.Route(ctx, iq.NotAcceptableError()) return } fromJID := iq.FromJID() log.Infof("retrieving private element. ns: %s... (%s/%s)", privNS, fromJID.Node(), fromJID.Resource()) - privElements, err := storage.FetchPrivateXML(privNS, fromJID.Node()) + privElements, err := x.rep.FetchPrivateXML(ctx, privNS, fromJID.Node()) if err != nil { log.Error(err) - _ = x.router.Route(iq.InternalServerError()) + _ = x.router.Route(ctx, iq.InternalServerError()) return } res := iq.ResultIQ() @@ -104,41 +105,41 @@ func (x *Private) getPrivate(iq *xmpp.IQ, q xmpp.XElement) { } res.AppendElement(query) - _ = x.router.Route(res) + _ = x.router.Route(ctx, res) } -func (x *Private) setPrivate(iq *xmpp.IQ, q xmpp.XElement) { +func (x *Private) setPrivate(ctx context.Context, iq *xmpp.IQ, q xmpp.XElement) { nsElements := map[string][]xmpp.XElement{} - for _, privElement := range q.Elements().All() { - ns := privElement.Namespace() + for _, prvElement := range q.Elements().All() { + ns := prvElement.Namespace() if len(ns) == 0 { - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) return } - if !x.isValidNamespace(privElement.Namespace()) { - _ = x.router.Route(iq.NotAcceptableError()) + if !x.isValidNamespace(prvElement.Namespace()) { + _ = x.router.Route(ctx, iq.NotAcceptableError()) return } - elems := nsElements[ns] - if elems == nil { - elems = []xmpp.XElement{privElement} + elements := nsElements[ns] + if elements == nil { + elements = []xmpp.XElement{prvElement} } else { - elems = append(elems, privElement) + elements = append(elements, prvElement) } - nsElements[ns] = elems + nsElements[ns] = elements } fromJID := iq.FromJID() for ns, elements := range nsElements { log.Infof("saving private element. ns: %s... (%s/%s)", ns, fromJID.Node(), fromJID.Resource()) - if err := storage.InsertOrUpdatePrivateXML(elements, ns, fromJID.Node()); err != nil { + if err := x.rep.UpsertPrivateXML(ctx, elements, ns, fromJID.Node()); err != nil { log.Error(err) - _ = x.router.Route(iq.InternalServerError()) + _ = x.router.Route(ctx, iq.InternalServerError()) return } } - _ = x.router.Route(iq.ResultIQ()) + _ = x.router.Route(ctx, iq.ResultIQ()) } func (x *Private) isValidNamespace(ns string) bool { diff --git a/module/xep0049/private_test.go b/module/xep0049/private_test.go index be05c9fc7..19d57c901 100644 --- a/module/xep0049/private_test.go +++ b/module/xep0049/private_test.go @@ -6,12 +6,16 @@ package xep0049 import ( + "context" "crypto/tls" "testing" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -24,10 +28,10 @@ func TestXEP0049_Matching(t *testing.T) { j2, _ := jid.New("romeo", "jackal.im", "balcony", true) stm := stream.NewMockC2S("abcd", j1) - defer stm.Disconnect(nil) + defer stm.Disconnect(context.Background(), nil) - x := New(nil) - defer x.Shutdown() + x := New(nil, nil) + defer func() { _ = x.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq.SetFromJID(j1) @@ -39,17 +43,18 @@ func TestXEP0049_Matching(t *testing.T) { } func TestXEP0049_InvalidIQ(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("romeo", "jackal.im", "balcony", true) stm := stream.NewMockC2S("abcd", j1) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) - x := New(r) - defer x.Shutdown() + x := New(r, s) + defer func() { _ = x.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq.SetFromJID(j1) @@ -57,52 +62,53 @@ func TestXEP0049_InvalidIQ(t *testing.T) { q := xmpp.NewElementNamespace("query", privateNamespace) iq.AppendElement(q) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrForbidden.Error(), elem.Error().Elements().All()[0].Name()) iq.SetType(xmpp.ResultType) iq.SetToJID(j1.ToBareJID()) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) iq.SetType(xmpp.GetType) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAcceptable.Error(), elem.Error().Elements().All()[0].Name()) exodus := xmpp.NewElementNamespace("exodus", "exodus:ns") exodus.AppendElement(xmpp.NewElementName("exodus2")) q.AppendElement(exodus) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAcceptable.Error(), elem.Error().Elements().All()[0].Name()) exodus.ClearElements() exodus.SetNamespace("jabber:client") iq.SetType(xmpp.SetType) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAcceptable.Error(), elem.Error().Elements().All()[0].Name()) exodus.SetNamespace("") - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) } func TestXEP0049_SetAndGetPrivate(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S("abcd", j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) - x := New(r) - defer x.Shutdown() + x := New(r, s) + defer func() { _ = x.Shutdown() }() iqID := uuid.New() iq := xmpp.NewIQType(iqID, xmpp.SetType) @@ -117,14 +123,14 @@ func TestXEP0049_SetAndGetPrivate(t *testing.T) { q.AppendElement(exodus2) // set error - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() // set success - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ResultType, elem.Type()) require.Equal(t, iqID, elem.ID()) @@ -133,14 +139,14 @@ func TestXEP0049_SetAndGetPrivate(t *testing.T) { q.RemoveElements("exodus2") iq.SetType(xmpp.GetType) - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() // get success - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ResultType, elem.Type()) require.Equal(t, iqID, elem.ID()) @@ -151,7 +157,7 @@ func TestXEP0049_SetAndGetPrivate(t *testing.T) { // get non existing exodus1.SetNamespace("exodus:ns:2") - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ResultType, elem.Type()) require.Equal(t, iqID, elem.ID()) @@ -160,13 +166,13 @@ func TestXEP0049_SetAndGetPrivate(t *testing.T) { require.Equal(t, "exodus:ns:2", q3.Elements().All()[0].Namespace()) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, repository.Private) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + s := memorystorage.NewPrivate() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r, s } diff --git a/module/xep0054/vcard.go b/module/xep0054/vcard.go index bcd0e5332..c1fd8a292 100644 --- a/module/xep0054/vcard.go +++ b/module/xep0054/vcard.go @@ -6,11 +6,13 @@ package xep0054 import ( + "context" + "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module/xep0030" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" ) @@ -18,15 +20,17 @@ const vCardNamespace = "vcard-temp" // VCard represents a vCard server stream module. type VCard struct { - router *router.Router + router router.Router runQueue *runqueue.RunQueue + rep repository.VCard } // New returns a vCard IQ handler module. -func New(disco *xep0030.DiscoInfo, router *router.Router) *VCard { +func New(disco *xep0030.DiscoInfo, router router.Router, rep repository.VCard) *VCard { v := &VCard{ router: router, runQueue: runqueue.New("xep0054"), + rep: rep, } if disco != nil { disco.RegisterServerFeature(vCardNamespace) @@ -41,11 +45,10 @@ func (x *VCard) MatchesIQ(iq *xmpp.IQ) bool { return (iq.IsGet() || iq.IsSet()) && iq.Elements().ChildNamespace("vCard", vCardNamespace) != nil } -// ProcessIQ processes a vCard IQ taking according actions -// over the associated stream. -func (x *VCard) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a vCard IQ taking according actions over the associated stream. +func (x *VCard) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - x.processIQ(iq) + x.processIQ(ctx, iq) }) } @@ -57,33 +60,33 @@ func (x *VCard) Shutdown() error { return nil } -func (x *VCard) processIQ(iq *xmpp.IQ) { +func (x *VCard) processIQ(ctx context.Context, iq *xmpp.IQ) { vCard := iq.Elements().ChildNamespace("vCard", vCardNamespace) if vCard != nil { if iq.IsGet() { - x.getVCard(vCard, iq) + x.getVCard(ctx, vCard, iq) return } else if iq.IsSet() { - x.setVCard(vCard, iq) + x.setVCard(ctx, vCard, iq) return } } - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) } -func (x *VCard) getVCard(vCard xmpp.XElement, iq *xmpp.IQ) { +func (x *VCard) getVCard(ctx context.Context, vCard xmpp.XElement, iq *xmpp.IQ) { if vCard.Elements().Count() > 0 { - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) return } toJID := iq.ToJID() - resElem, err := storage.FetchVCard(toJID.Node()) + resElem, err := x.rep.FetchVCard(ctx, toJID.Node()) if err != nil { log.Errorf("%v", err) - _ = x.router.Route(iq.InternalServerError()) + _ = x.router.Route(ctx, iq.InternalServerError()) return } - log.Infof("retrieving vcard... (%s/%s)", toJID.Node(), toJID.Resource()) + log.Infof("retrieving vcard... (jid: %s)", toJID.String()) resultIQ := iq.ResultIQ() if resElem != nil { @@ -92,24 +95,24 @@ func (x *VCard) getVCard(vCard xmpp.XElement, iq *xmpp.IQ) { // empty vCard resultIQ.AppendElement(xmpp.NewElementNamespace("vCard", vCardNamespace)) } - _ = x.router.Route(resultIQ) + _ = x.router.Route(ctx, resultIQ) } -func (x *VCard) setVCard(vCard xmpp.XElement, iq *xmpp.IQ) { +func (x *VCard) setVCard(ctx context.Context, vCard xmpp.XElement, iq *xmpp.IQ) { fromJID := iq.FromJID() toJID := iq.ToJID() if toJID.IsServer() || (toJID.Node() == fromJID.Node()) { - log.Infof("saving vcard... (%s/%s)", toJID.Node(), toJID.Resource()) + log.Infof("saving vcard... (jid: %s)", toJID.String()) - err := storage.InsertOrUpdateVCard(vCard, toJID.Node()) + err := x.rep.UpsertVCard(ctx, vCard, toJID.Node()) if err != nil { log.Error(err) - _ = x.router.Route(iq.InternalServerError()) + _ = x.router.Route(ctx, iq.InternalServerError()) return } - _ = x.router.Route(iq.ResultIQ()) + _ = x.router.Route(ctx, iq.ResultIQ()) } else { - _ = x.router.Route(iq.ForbiddenError()) + _ = x.router.Route(ctx, iq.ForbiddenError()) } } diff --git a/module/xep0054/vcard_test.go b/module/xep0054/vcard_test.go index 7128054d0..5f38d0312 100644 --- a/module/xep0054/vcard_test.go +++ b/module/xep0054/vcard_test.go @@ -6,12 +6,15 @@ package xep0054 import ( + "context" "crypto/tls" "testing" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -22,8 +25,8 @@ import ( func TestXEP0054_Matching(t *testing.T) { j, _ := jid.New("ortuman", "jackal.im", "balcony", true) - x := New(nil, nil) - defer x.Shutdown() + x := New(nil, nil, nil) + defer func() { _ = x.Shutdown() }() // test MatchesIQ iqID := uuid.New() @@ -44,13 +47,14 @@ func TestXEP0054_Matching(t *testing.T) { } func TestXEP0054_Set(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S("abcd", j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) iqID := uuid.New() iq := xmpp.NewIQType(iqID, xmpp.SetType) @@ -58,10 +62,10 @@ func TestXEP0054_Set(t *testing.T) { iq.SetToJID(j.ToBareJID()) iq.AppendElement(testVCard()) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, s) + defer func() { _ = x.Shutdown() }() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.NotNil(t, elem) require.Equal(t, xmpp.ResultType, elem.Type()) @@ -74,7 +78,7 @@ func TestXEP0054_Set(t *testing.T) { iq2.SetToJID(j.ToBareJID()) iq2.AppendElement(xmpp.NewElementNamespace("vCard", vCardNamespace)) - x.ProcessIQ(iq2) + x.ProcessIQ(context.Background(), iq2) elem = stm.ReceiveElement() require.NotNil(t, elem) require.Equal(t, xmpp.ResultType, elem.Type()) @@ -82,17 +86,18 @@ func TestXEP0054_Set(t *testing.T) { } func TestXEP0054_SetError(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("romeo", "jackal.im", "garden", true) stm := stream.NewMockC2S("abcd", j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, s) + defer func() { _ = x.Shutdown() }() // set other user vCard... iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) @@ -100,43 +105,44 @@ func TestXEP0054_SetError(t *testing.T) { iq.SetToJID(j2.ToBareJID()) iq.AppendElement(testVCard()) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrForbidden.Error(), elem.Error().Elements().All()[0].Name()) // storage error - s.EnableMockedError() - defer s.DisableMockedError() + memorystorage.EnableMockedError() + defer memorystorage.DisableMockedError() iq2 := xmpp.NewIQType(uuid.New(), xmpp.SetType) iq2.SetFromJID(j) iq2.SetToJID(j.ToBareJID()) iq2.AppendElement(testVCard()) - x.ProcessIQ(iq2) + x.ProcessIQ(context.Background(), iq2) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) } func TestXEP0054_Get(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("romeo", "jackal.im", "garden", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) iqSet := xmpp.NewIQType(uuid.New(), xmpp.SetType) iqSet.SetFromJID(j) iqSet.SetToJID(j.ToBareJID()) iqSet.AppendElement(testVCard()) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, s) + defer func() { _ = x.Shutdown() }() - x.ProcessIQ(iqSet) + x.ProcessIQ(context.Background(), iqSet) _ = stm.ReceiveElement() // wait until set... iqGetID := uuid.New() @@ -145,7 +151,7 @@ func TestXEP0054_Get(t *testing.T) { iqGet.SetToJID(j.ToBareJID()) iqGet.AppendElement(xmpp.NewElementNamespace("vCard", vCardNamespace)) - x.ProcessIQ(iqGet) + x.ProcessIQ(context.Background(), iqGet) elem := stm.ReceiveElement() require.NotNil(t, elem) vCard := elem.Elements().ChildNamespace("vCard", vCardNamespace) @@ -159,7 +165,7 @@ func TestXEP0054_Get(t *testing.T) { iqGet2.SetToJID(j2.ToBareJID()) iqGet2.AppendElement(xmpp.NewElementNamespace("vCard", vCardNamespace)) - x.ProcessIQ(iqGet2) + x.ProcessIQ(context.Background(), iqGet2) elem = stm.ReceiveElement() require.NotNil(t, elem) vCard = elem.Elements().ChildNamespace("vCard", vCardNamespace) @@ -167,23 +173,24 @@ func TestXEP0054_Get(t *testing.T) { } func TestXEP0054_GetError(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S("abcd", j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) iqSet := xmpp.NewIQType(uuid.New(), xmpp.SetType) iqSet.SetFromJID(j) iqSet.SetToJID(j.ToBareJID()) iqSet.AppendElement(testVCard()) - x := New(nil, r) - defer x.Shutdown() + x := New(nil, r, s) + defer func() { _ = x.Shutdown() }() - x.ProcessIQ(iqSet) + x.ProcessIQ(context.Background(), iqSet) _ = stm.ReceiveElement() // wait until set... iqGetID := uuid.New() @@ -194,7 +201,7 @@ func TestXEP0054_GetError(t *testing.T) { vCard.AppendElement(xmpp.NewElementName("FN")) iqGet.AppendElement(vCard) - x.ProcessIQ(iqGet) + x.ProcessIQ(context.Background(), iqGet) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) @@ -204,10 +211,10 @@ func TestXEP0054_GetError(t *testing.T) { iqGet2.SetToJID(j.ToBareJID()) iqGet2.AppendElement(xmpp.NewElementNamespace("vCard", vCardNamespace)) - s.EnableMockedError() - defer s.DisableMockedError() + memorystorage.EnableMockedError() + defer memorystorage.DisableMockedError() - x.ProcessIQ(iqGet2) + x.ProcessIQ(context.Background(), iqGet2) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) } @@ -223,13 +230,13 @@ func testVCard() xmpp.XElement { return vCard } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, *memorystorage.VCard) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + s := memorystorage.NewVCard() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r, s } diff --git a/module/xep0077/register.go b/module/xep0077/register.go index 63e87a390..0958f4c17 100644 --- a/module/xep0077/register.go +++ b/module/xep0077/register.go @@ -6,13 +6,15 @@ package xep0077 import ( + "context" + "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/model" "github.com/ortuman/jackal/module/xep0030" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) @@ -31,16 +33,18 @@ type Config struct { // Register represents an in-band server stream module. type Register struct { cfg *Config - router *router.Router + router router.Router runQueue *runqueue.RunQueue + rep repository.User } // New returns an in-band registration IQ handler. -func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Register { +func New(config *Config, disco *xep0030.DiscoInfo, router router.Router, userRep repository.User) *Register { r := &Register{ cfg: config, router: router, runQueue: runqueue.New("xep0077"), + rep: userRep, } if disco != nil { disco.RegisterServerFeature(registerNamespace) @@ -48,27 +52,24 @@ func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Regis return r } -// MatchesIQ returns whether or not an IQ should be -// processed by the in-band registration module. +// MatchesIQ returns whether or not an IQ should be processed by the in-band registration module. func (x *Register) MatchesIQ(iq *xmpp.IQ) bool { return iq.Elements().ChildNamespace("query", registerNamespace) != nil } -// ProcessIQ processes an in-band registration IQ taking according actions over -// the associated stream. -func (x *Register) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes an in-band registration IQ taking according actions over the associated stream. +func (x *Register) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - if stm := x.router.UserStream(iq.FromJID()); stm != nil { - x.processIQ(iq, stm) + if stm := x.router.LocalStream(iq.FromJID().Node(), iq.FromJID().Resource()); stm != nil { + x.processIQ(ctx, iq, stm) } }) } -// ProcessIQWithStream processes an in-band registration IQ taking according -// actions over a referenced stream. -func (x *Register) ProcessIQWithStream(iq *xmpp.IQ, stm stream.C2S) { +// ProcessIQWithStream processes an in-band registration IQ taking according actions over a referenced stream. +func (x *Register) ProcessIQWithStream(ctx context.Context, iq *xmpp.IQ, stm stream.C2S) { x.runQueue.Run(func() { - x.processIQ(iq, stm) + x.processIQ(ctx, iq, stm) }) } @@ -80,53 +81,54 @@ func (x *Register) Shutdown() error { return nil } -func (x *Register) processIQ(iq *xmpp.IQ, stm stream.C2S) { +func (x *Register) processIQ(ctx context.Context, iq *xmpp.IQ, stm stream.C2S) { if !x.isValidToJid(iq.ToJID(), stm) { - stm.SendElement(iq.ForbiddenError()) + stm.SendElement(ctx, iq.ForbiddenError()) return } q := iq.Elements().ChildNamespace("query", registerNamespace) if !stm.IsAuthenticated() { if iq.IsGet() { if !x.cfg.AllowRegistration { - stm.SendElement(iq.NotAllowedError()) + stm.SendElement(ctx, iq.NotAllowedError()) return } // ...send registration fields to requester entity... - x.sendRegistrationFields(iq, q, stm) + x.sendRegistrationFields(ctx, iq, q, stm) } else if iq.IsSet() { - if !stm.GetBool(xep077RegisteredCtxKey) { + registered, _ := stm.Value(xep077RegisteredCtxKey).(bool) + if !registered { // ...register a new user... - x.registerNewUser(iq, q, stm) + x.registerNewUser(ctx, iq, q, stm) } else { // return a stanza error if an entity attempts to register a second identity - stm.SendElement(iq.NotAcceptableError()) + stm.SendElement(ctx, iq.NotAcceptableError()) } } else { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) } } else if iq.IsSet() { if q.Elements().Child("remove") != nil { // remove user - x.cancelRegistration(iq, q, stm) + x.cancelRegistration(ctx, iq, q, stm) } else { user := q.Elements().Child("username") password := q.Elements().Child("password") if user != nil && password != nil { // change password - x.changePassword(password.Text(), user.Text(), iq, stm) + x.changePassword(ctx, password.Text(), user.Text(), iq, stm) } else { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) } } } else { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) } } -func (x *Register) sendRegistrationFields(iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) { +func (x *Register) sendRegistrationFields(ctx context.Context, iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) { if query.Elements().Count() > 0 { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return } result := iq.ResultIQ() @@ -134,24 +136,24 @@ func (x *Register) sendRegistrationFields(iq *xmpp.IQ, query xmpp.XElement, stm q.AppendElement(xmpp.NewElementName("username")) q.AppendElement(xmpp.NewElementName("password")) result.AppendElement(q) - stm.SendElement(result) + stm.SendElement(ctx, result) } -func (x *Register) registerNewUser(iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) { +func (x *Register) registerNewUser(ctx context.Context, iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) { userEl := query.Elements().Child("username") passwordEl := query.Elements().Child("password") if userEl == nil || passwordEl == nil || len(userEl.Text()) == 0 || len(passwordEl.Text()) == 0 { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return } - exists, err := storage.UserExists(userEl.Text()) + exists, err := x.rep.UserExists(ctx, userEl.Text()) if err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } if exists { - stm.SendElement(iq.ConflictError()) + stm.SendElement(ctx, iq.ConflictError()) return } user := model.User{ @@ -159,65 +161,65 @@ func (x *Register) registerNewUser(iq *xmpp.IQ, query xmpp.XElement, stm stream. Password: passwordEl.Text(), LastPresence: xmpp.NewPresence(stm.JID(), stm.JID(), xmpp.UnavailableType), } - if err := storage.InsertOrUpdateUser(&user); err != nil { + if err := x.rep.UpsertUser(ctx, &user); err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } - stm.SendElement(iq.ResultIQ()) - stm.SetBool(xep077RegisteredCtxKey, true) // mark as registered + stm.SendElement(ctx, iq.ResultIQ()) + stm.SetValue(xep077RegisteredCtxKey, true) // mark as registered } -func (x *Register) cancelRegistration(iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) { +func (x *Register) cancelRegistration(ctx context.Context, iq *xmpp.IQ, query xmpp.XElement, stm stream.C2S) { if !x.cfg.AllowCancel { - stm.SendElement(iq.NotAllowedError()) + stm.SendElement(ctx, iq.NotAllowedError()) return } if query.Elements().Count() > 1 { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return } - if err := storage.DeleteUser(stm.Username()); err != nil { + if err := x.rep.DeleteUser(ctx, stm.Username()); err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } - stm.SendElement(iq.ResultIQ()) + stm.SendElement(ctx, iq.ResultIQ()) } -func (x *Register) changePassword(password string, username string, iq *xmpp.IQ, stm stream.C2S) { +func (x *Register) changePassword(ctx context.Context, password string, username string, iq *xmpp.IQ, stm stream.C2S) { if !x.cfg.AllowChange { - stm.SendElement(iq.NotAllowedError()) + stm.SendElement(ctx, iq.NotAllowedError()) return } if username != stm.Username() { - stm.SendElement(iq.NotAllowedError()) + stm.SendElement(ctx, iq.NotAllowedError()) return } if !stm.IsSecured() { // channel isn't safe enough to enable a password change - stm.SendElement(iq.NotAuthorizedError()) + stm.SendElement(ctx, iq.NotAuthorizedError()) return } - user, err := storage.FetchUser(username) + user, err := x.rep.FetchUser(ctx, username) if err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } if user == nil { - stm.SendElement(iq.ResultIQ()) + stm.SendElement(ctx, iq.ResultIQ()) return } if user.Password != password { user.Password = password - if err := storage.InsertOrUpdateUser(user); err != nil { + if err := x.rep.UpsertUser(ctx, user); err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } } - stm.SendElement(iq.ResultIQ()) + stm.SendElement(ctx, iq.ResultIQ()) } func (x *Register) isValidToJid(j *jid.JID, stm stream.C2S) bool { diff --git a/module/xep0077/register_test.go b/module/xep0077/register_test.go index 456f2873f..ce3cd1f3e 100644 --- a/module/xep0077/register_test.go +++ b/module/xep0077/register_test.go @@ -6,13 +6,16 @@ package xep0077 import ( + "context" "crypto/tls" "testing" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/model" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -21,10 +24,12 @@ import ( ) func TestXEP0077_Matching(t *testing.T) { + r, s := setupTest("jackal.im") + j, _ := jid.New("ortuman", "jackal.im", "balcony", true) - x := New(&Config{}, nil, nil) - defer x.Shutdown() + x := New(&Config{}, nil, r, s) + defer func() { _ = x.Shutdown() }() // test MatchesIQ iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) @@ -36,24 +41,23 @@ func TestXEP0077_Matching(t *testing.T) { } func TestXEP0077_InvalidToJID(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j1, _ := jid.New("romeo", "jackal.im", "balcony", true) j2, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm1 := stream.NewMockC2S(uuid.New(), j1) - r.Bind(stm1) + r.Bind(context.Background(), stm1) - x := New(&Config{}, nil, r) - defer x.Shutdown() + x := New(&Config{}, nil, r, s) + defer func() { _ = x.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) iq.SetFromJID(j1) iq.SetToJID(j2.ToBareJID()) stm1.SetAuthenticated(true) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm1.ReceiveElement() require.Equal(t, xmpp.ErrForbidden.Error(), elem.Error().Elements().All()[0].Name()) @@ -63,94 +67,91 @@ func TestXEP0077_InvalidToJID(t *testing.T) { } func TestXEP0077_NotAuthenticatedErrors(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + r.Bind(context.Background(), stm) - x := New(&Config{}, nil, r) - defer x.Shutdown() + x := New(&Config{}, nil, r, s) + defer func() { _ = x.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.ResultType) iq.SetFromJID(j) iq.SetToJID(j.ToBareJID()) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) iq.SetType(xmpp.GetType) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAllowed.Error(), elem.Error().Elements().All()[0].Name()) // allow registration... - x = New(&Config{AllowRegistration: true}, nil, r) - defer x.Shutdown() + x = New(&Config{AllowRegistration: true}, nil, r, s) + defer func() { _ = x.Shutdown() }() q := xmpp.NewElementNamespace("query", registerNamespace) q.AppendElement(xmpp.NewElementName("q2")) iq.AppendElement(q) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) q.ClearElements() iq.SetType(xmpp.SetType) - stm.SetBool(xep077RegisteredCtxKey, true) + stm.SetValue(xep077RegisteredCtxKey, true) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAcceptable.Error(), elem.Error().Elements().All()[0].Name()) } func TestXEP0077_AuthenticatedErrors(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") srvJid, _ := jid.New("", "jackal.im", "", true) j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + r.Bind(context.Background(), stm) stm.SetAuthenticated(true) - x := New(&Config{}, nil, r) - defer x.Shutdown() + x := New(&Config{}, nil, r, s) + defer func() { _ = x.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.ResultType) iq.SetFromJID(j) iq.SetToJID(j.ToBareJID()) iq.SetToJID(srvJid) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) iq.SetType(xmpp.SetType) iq.AppendElement(xmpp.NewElementNamespace("query", registerNamespace)) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) } func TestXEP0077_RegisterUser(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") srvJid, _ := jid.New("", "jackal.im", "", true) j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + r.Bind(context.Background(), stm) - x := New(&Config{AllowRegistration: true}, nil, r) - defer x.Shutdown() + x := New(&Config{AllowRegistration: true}, nil, r, s) + defer func() { _ = x.Shutdown() }() iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq.SetFromJID(j) @@ -159,7 +160,7 @@ func TestXEP0077_RegisterUser(t *testing.T) { q := xmpp.NewElementNamespace("query", registerNamespace) iq.AppendElement(q) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) q2 := stm.ReceiveElement().Elements().ChildNamespace("query", registerNamespace) require.NotNil(t, q2.Elements().Child("username")) require.NotNil(t, q2.Elements().Child("password")) @@ -171,50 +172,49 @@ func TestXEP0077_RegisterUser(t *testing.T) { // empty fields iq.SetType(xmpp.SetType) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) // already existing user... - storage.InsertOrUpdateUser(&model.User{Username: "ortuman", Password: "1234"}) + _ = s.UpsertUser(context.Background(), &model.User{Username: "ortuman", Password: "1234"}) username.SetText("ortuman") password.SetText("5678") - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrConflict.Error(), elem.Error().Elements().All()[0].Name()) // storage error - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() username.SetText("juliet") - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ResultType, elem.Type()) - usr, _ := storage.FetchUser("ortuman") + usr, _ := s.FetchUser(context.Background(), "ortuman") require.NotNil(t, usr) } func TestXEP0077_CancelRegistration(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") srvJid, _ := jid.New("", "jackal.im", "", true) j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S("abcd1234", j) - r.Bind(stm) + r.Bind(context.Background(), stm) stm.SetAuthenticated(true) - x := New(&Config{}, nil, r) - defer x.Shutdown() + x := New(&Config{}, nil, r, s) + defer func() { _ = x.Shutdown() }() - storage.InsertOrUpdateUser(&model.User{Username: "ortuman", Password: "1234"}) + _ = s.UpsertUser(context.Background(), &model.User{Username: "ortuman", Password: "1234"}) iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) iq.SetFromJID(j) @@ -224,51 +224,50 @@ func TestXEP0077_CancelRegistration(t *testing.T) { q.AppendElement(xmpp.NewElementName("remove")) iq.AppendElement(q) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAllowed.Error(), elem.Error().Elements().All()[0].Name()) - x = New(&Config{AllowCancel: true}, nil, r) - defer x.Shutdown() + x = New(&Config{AllowCancel: true}, nil, r, s) + defer func() { _ = x.Shutdown() }() q.AppendElement(xmpp.NewElementName("remove2")) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) q.ClearElements() q.AppendElement(xmpp.NewElementName("remove")) // storage error - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ResultType, elem.Type()) - usr, _ := storage.FetchUser("ortuman") + usr, _ := s.FetchUser(context.Background(), "ortuman") require.Nil(t, usr) } func TestXEP0077_ChangePassword(t *testing.T) { - r, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, s := setupTest("jackal.im") srvJid, _ := jid.New("", "jackal.im", "", true) j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + r.Bind(context.Background(), stm) stm.SetAuthenticated(true) - x := New(&Config{}, nil, r) - defer x.Shutdown() + x := New(&Config{}, nil, r, s) + defer func() { _ = x.Shutdown() }() - storage.InsertOrUpdateUser(&model.User{Username: "ortuman", Password: "1234"}) + _ = s.UpsertUser(context.Background(), &model.User{Username: "ortuman", Password: "1234"}) iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) iq.SetFromJID(j) @@ -283,19 +282,19 @@ func TestXEP0077_ChangePassword(t *testing.T) { q.AppendElement(password) iq.AppendElement(q) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAllowed.Error(), elem.Error().Elements().All()[0].Name()) - x = New(&Config{AllowChange: true}, nil, r) - defer x.Shutdown() + x = New(&Config{AllowChange: true}, nil, r, s) + defer func() { _ = x.Shutdown() }() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAllowed.Error(), elem.Error().Elements().All()[0].Name()) username.SetText("ortuman") - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrNotAuthorized.Error(), elem.Error().Elements().All()[0].Name()) @@ -303,28 +302,28 @@ func TestXEP0077_ChangePassword(t *testing.T) { stm.SetSecured(true) // storage error - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ResultType, elem.Type()) - usr, _ := storage.FetchUser("ortuman") + usr, _ := s.FetchUser(context.Background(), "ortuman") require.NotNil(t, usr) require.Equal(t, "5678", usr.Password) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, *memorystorage.User) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + userRep := memorystorage.NewUser() + r, _ := router.New( + hosts, + c2srouter.New(userRep, memorystorage.NewBlockList()), + nil, + ) + return r, userRep } diff --git a/module/xep0092/version.go b/module/xep0092/version.go index 5091141f2..b4841d954 100644 --- a/module/xep0092/version.go +++ b/module/xep0092/version.go @@ -6,13 +6,14 @@ package xep0092 import ( + "context" "os/exec" "strings" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module/xep0030" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/version" "github.com/ortuman/jackal/xmpp" ) @@ -34,12 +35,12 @@ type Config struct { // Version represents a version module. type Version struct { cfg *Config - router *router.Router + router router.Router runQueue *runqueue.RunQueue } // New returns a version IQ handler module. -func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Version { +func New(config *Config, disco *xep0030.DiscoInfo, router router.Router) *Version { v := &Version{ cfg: config, router: router, @@ -51,17 +52,15 @@ func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Versi return v } -// MatchesIQ returns whether or not an IQ should be -// processed by the version module. +// MatchesIQ returns whether or not an IQ should be processed by the version module. func (x *Version) MatchesIQ(iq *xmpp.IQ) bool { return iq.IsGet() && iq.Elements().ChildNamespace("query", versionNamespace) != nil && iq.ToJID().IsServer() } -// ProcessIQ processes a version IQ taking according actions -// over the associated stream. -func (x *Version) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a version IQ taking according actions over the associated stream. +func (x *Version) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - x.processIQ(iq) + x.processIQ(ctx, iq) }) } @@ -73,16 +72,16 @@ func (x *Version) Shutdown() error { return nil } -func (x *Version) processIQ(iq *xmpp.IQ) { +func (x *Version) processIQ(ctx context.Context, iq *xmpp.IQ) { q := iq.Elements().ChildNamespace("query", versionNamespace) if q == nil || q.Elements().Count() != 0 { - _ = x.router.Route(iq.BadRequestError()) + _ = x.router.Route(ctx, iq.BadRequestError()) return } - x.sendSoftwareVersion(iq) + x.sendSoftwareVersion(ctx, iq) } -func (x *Version) sendSoftwareVersion(iq *xmpp.IQ) { +func (x *Version) sendSoftwareVersion(ctx context.Context, iq *xmpp.IQ) { userJID := iq.FromJID() username := userJID.Node() resource := userJID.Resource() @@ -105,5 +104,5 @@ func (x *Version) sendSoftwareVersion(iq *xmpp.IQ) { query.AppendElement(os) } result.AppendElement(query) - _ = x.router.Route(result) + _ = x.router.Route(ctx, result) } diff --git a/module/xep0092/version_test.go b/module/xep0092/version_test.go index fe015473c..2b03b1165 100644 --- a/module/xep0092/version_test.go +++ b/module/xep0092/version_test.go @@ -6,10 +6,15 @@ package xep0092 import ( + "context" "crypto/tls" "testing" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/router" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/version" "github.com/ortuman/jackal/xmpp" @@ -19,19 +24,19 @@ import ( ) func TestXEP0092(t *testing.T) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: "jackal.im", Certificate: tls.Certificate{}}}, - }) + r := setupTest() srvJID, _ := jid.New("", "jackal.im", "", true) j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - r.Bind(stm) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) cfg := Config{} x := New(&cfg, nil, r) - defer x.Shutdown() + defer func() { _ = x.Shutdown() }() // test MatchesIQ iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) @@ -49,13 +54,13 @@ func TestXEP0092(t *testing.T) { require.True(t, x.MatchesIQ(iq)) qVer.AppendElement(xmpp.NewElementName("version")) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) // get version qVer.ClearElements() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() ver := elem.Elements().ChildNamespace("query", versionNamespace) require.Equal(t, "jackal", ver.Elements().Child("name").Text()) @@ -66,10 +71,20 @@ func TestXEP0092(t *testing.T) { cfg.ShowOS = true x = New(&cfg, nil, r) - defer x.Shutdown() + defer func() { _ = x.Shutdown() }() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() ver = elem.Elements().ChildNamespace("query", versionNamespace) require.Equal(t, osString, ver.Elements().Child("os").Text()) } + +func setupTest() router.Router { + hosts, _ := host.New([]host.Config{{Name: "jackal.im", Certificate: tls.Certificate{}}}) + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r +} diff --git a/module/xep0115/entity_caps.go b/module/xep0115/entity_caps.go new file mode 100644 index 000000000..12260a321 --- /dev/null +++ b/module/xep0115/entity_caps.go @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0115 + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/ortuman/jackal/log" + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/pborman/uuid" +) + +const ( + discoInfoNamespace = "http://jabber.org/protocol/disco#info" +) + +// EntityCaps represents global entity capabilities module +type EntityCaps struct { + allocationID string + runQueue *runqueue.RunQueue + router router.Router + presencesRep repository.Presences + mu sync.RWMutex + activeDiscoInfo map[string]bool +} + +// New returns a new presence hub instance. +func New(router router.Router, presencesRep repository.Presences, allocationID string) *EntityCaps { + return &EntityCaps{ + runQueue: runqueue.New("xep0115"), + router: router, + presencesRep: presencesRep, + allocationID: allocationID, + activeDiscoInfo: make(map[string]bool), + } +} + +// RegisterPresence keeps track of a new client presence, requesting capabilities when necessary. +func (x *EntityCaps) RegisterPresence(ctx context.Context, presence *xmpp.Presence) (alreadyRegistered bool, err error) { + fromJID := presence.FromJID() + + // check if caps were previously cached + if c := presence.Capabilities(); c != nil { + if err := x.registerCapabilities(ctx, c.Node, c.Ver, presence.FromJID()); err != nil { + return false, err + } + } + // store available presence + inserted, err := x.presencesRep.UpsertPresence(ctx, presence, fromJID, x.allocationID) + if err != nil { + return false, err + } + return inserted, nil +} + +// UnregisterPresence removes a presence from the hub. +func (x *EntityCaps) UnregisterPresence(ctx context.Context, jid *jid.JID) error { + return x.presencesRep.DeletePresence(ctx, jid) +} + +// MatchesIQ returns whether or not an IQ should be processed by the roster module. +func (x *EntityCaps) MatchesIQ(iq *xmpp.IQ) bool { + x.mu.RLock() + defer x.mu.RUnlock() + _, ok := x.activeDiscoInfo[iq.ID()] + return ok && iq.IsResult() +} + +// ProcessIQ processes a roster IQ taking according actions over the associated stream. +func (x *EntityCaps) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { + x.runQueue.Run(func() { + x.processIQ(ctx, iq) + }) +} + +// Shutdown shuts down blocking module. +func (x *EntityCaps) Shutdown() error { + c := make(chan struct{}) + x.runQueue.Stop(func() { close(c) }) + <-c + return nil +} + +// PresencesMatchingJID returns current online presences matching a given JID. +func (x *EntityCaps) PresencesMatchingJID(ctx context.Context, jid *jid.JID) ([]capsmodel.PresenceCaps, error) { + return x.presencesRep.FetchPresencesMatchingJID(ctx, jid) +} + +func (x *EntityCaps) registerCapabilities(ctx context.Context, node, ver string, jid *jid.JID) error { + caps, err := x.presencesRep.FetchCapabilities(ctx, node, ver) // try fetching from disk + if err != nil { + return err + } + if caps == nil { + x.requestCapabilities(ctx, node, ver, jid) // request capabilities + } + return nil +} + +func (x *EntityCaps) processIQ(ctx context.Context, iq *xmpp.IQ) { + caps := iq.Elements().ChildNamespace("query", discoInfoNamespace) + if caps == nil { + return + } + // process capabilities result + if err := x.processCapabilitiesIQ(ctx, caps); err != nil { + log.Warnf("%v", err) + } +} + +func (x *EntityCaps) requestCapabilities(ctx context.Context, node, ver string, userJID *jid.JID) { + srvJID, _ := jid.NewWithString(x.router.Hosts().DefaultHostName(), true) + + iqID := uuid.New() + x.mu.Lock() + x.activeDiscoInfo[iqID] = true + x.mu.Unlock() + + iq := xmpp.NewIQType(iqID, xmpp.GetType) + iq.SetFromJID(srvJID) + iq.SetToJID(userJID) + + query := xmpp.NewElementNamespace("query", discoInfoNamespace) + query.SetAttribute("node", node+"#"+ver) + iq.AppendElement(query) + + log.Infof("requesting capabilities... node: %s, ver: %s", node, ver) + + _ = x.router.Route(ctx, iq) +} + +func (x *EntityCaps) processCapabilitiesIQ(ctx context.Context, query xmpp.XElement) error { + var node, ver string + + nodeStr := query.Attributes().Get("node") + ss := strings.Split(nodeStr, "#") + if len(ss) != 2 { + return fmt.Errorf("xep0115: wrong node format: %s", nodeStr) + } + node = ss[0] + ver = ss[1] + + // retrieve and store features + log.Infof("storing capabilities... node: %s, ver: %s", node, ver) + + var features []string + featureElems := query.Elements().Children("feature") + for _, featureElem := range featureElems { + features = append(features, featureElem.Attributes().Get("var")) + } + caps := &capsmodel.Capabilities{ + Node: node, + Ver: ver, + Features: features, + } + return x.presencesRep.UpsertCapabilities(ctx, caps) +} diff --git a/module/xep0115/entity_caps_test.go b/module/xep0115/entity_caps_test.go new file mode 100644 index 000000000..8ad4c4f96 --- /dev/null +++ b/module/xep0115/entity_caps_test.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0115 + +/* +func TestEntityCaps_RegisterPresence(t *testing.T) { +} + +func TestEntityCaps_RequestCapabilities(t *testing.T) { + r, s := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + + // register presence + p := xmpp.NewPresence(j1, j1, xmpp.AvailableType) + c := xmpp.NewElementNamespace("c", "http://jabber.org/protocol/caps") + c.SetAttribute("hash", "sha-1") + c.SetAttribute("node", "http://code.google.com/p/exodus") + c.SetAttribute("ver", "QgayPKawpkPSDYmwT/WM94uAlu0=") + p.AppendElement(c) + + ph := New(r, s, "alloc-1234") + _, _ = ph.RegisterPresence(context.Background(), p) + + elem := stm1.ReceiveElement() + require.Equal(t, "iq", elem.Name()) + require.Equal(t, "jackal.im", elem.From()) + + queryElem := elem.Elements().Child("query") + require.NotNil(t, queryElem) + + require.Equal(t, "http://jabber.org/protocol/disco#info", queryElem.Namespace()) + require.Equal(t, "http://code.google.com/p/exodus#QgayPKawpkPSDYmwT/WM94uAlu0=", queryElem.Attributes().Get("node")) +} + +func TestEntityCaps_ProcessCapabilities(t *testing.T) { + r, s := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + iqID := uuid.New() + + iqRes := xmpp.NewIQType(iqID, xmpp.ResultType) + iqRes.SetFromJID(j1) + iqRes.SetToJID(j1.ToBareJID()) + + qElem := xmpp.NewElementNamespace("query", "http://jabber.org/protocol/disco#info") + qElem.SetAttribute("node", "http://code.google.com/p/exodus#QgayPKawpkPSDYmwT/WM94uAlu0=") + featureEl := xmpp.NewElementName("feature") + featureEl.SetAttribute("var", "cool+feature") + qElem.AppendElement(featureEl) + iqRes.AppendElement(qElem) + + ph := New(r, s, "alloc-1234") + ph.activeDiscoInfo.Store(iqID, true) + + ph.processIQ(context.Background(), iqRes) + + // check storage capabilities + caps, _ := s.FetchCapabilities(context.Background(), "http://code.google.com/p/exodus", "QgayPKawpkPSDYmwT/WM94uAlu0=") + require.NotNil(t, caps) + + require.Len(t, caps.Features, 1) + require.Equal(t, "cool+feature", caps.Features[0]) +} + +func setupTest(domain string) (router.Router, *memorystorage.Presences) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + s := memorystorage.NewPresences() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r, s +} +*/ diff --git a/module/xep0163/access_checker.go b/module/xep0163/access_checker.go new file mode 100644 index 000000000..423a47497 --- /dev/null +++ b/module/xep0163/access_checker.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0163 + +import ( + "context" + "errors" + "fmt" + + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/xmpp/jid" +) + +var ( + errOutcastMember = errors.New("pep: outcast member") + errPresenceSubscriptionRequired = errors.New("pep: presence subscription required") + errNotInRosterGroup = errors.New("pep: not in roster group") + errNotOnWhiteList = errors.New("pep: not on whitelist") +) + +type accessChecker struct { + host string + nodeID string + accessModel string + rosterAllowedGroups []string + affiliation *pubsubmodel.Affiliation + rosterRep repository.Roster +} + +func (ac *accessChecker) checkAccess(ctx context.Context, j string) error { + aff := ac.affiliation + if aff != nil && aff.Affiliation == pubsubmodel.Outcast { + return errOutcastMember + } + switch ac.accessModel { + case pubsubmodel.Open: + return nil + + case pubsubmodel.Presence: + allowed, err := ac.checkPresenceAccess(ctx, j) + if err != nil { + return err + } + if !allowed { + return errPresenceSubscriptionRequired + } + + case pubsubmodel.Roster: + allowed, err := ac.checkRosterAccess(ctx, j) + if err != nil { + return err + } + if !allowed { + return errNotInRosterGroup + } + + case pubsubmodel.WhiteList: + allowed, err := ac.checkWhitelistAccess(j) + if err != nil { + return err + } + if !allowed { + return errNotOnWhiteList + } + + default: + return fmt.Errorf("pep: unrecognized access model: %s", ac.accessModel) + } + return nil +} + +func (ac *accessChecker) checkPresenceAccess(ctx context.Context, j string) (bool, error) { + userJID, _ := jid.NewWithString(ac.host, true) + contactJID, _ := jid.NewWithString(j, true) + + ri, err := ac.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.ToBareJID().String()) + if err != nil { + return false, err + } + allowed := ri != nil && (ri.Subscription == rostermodel.SubscriptionFrom || ri.Subscription == rostermodel.SubscriptionBoth) + return allowed, nil +} + +func (ac *accessChecker) checkRosterAccess(ctx context.Context, j string) (bool, error) { + userJID, _ := jid.NewWithString(ac.host, true) + contactJID, _ := jid.NewWithString(j, true) + + ri, err := ac.rosterRep.FetchRosterItem(ctx, userJID.Node(), contactJID.ToBareJID().String()) + if err != nil { + return false, err + } + if ri == nil { + return false, nil + } + for _, group := range ri.Groups { + for _, allowedGroup := range ac.rosterAllowedGroups { + if group == allowedGroup { + return true, nil + } + } + } + return false, nil +} + +func (ac *accessChecker) checkWhitelistAccess(j string) (bool, error) { + aff := ac.affiliation + if aff == nil || j != aff.JID { + return false, nil + } + return aff.Affiliation == pubsubmodel.Owner || aff.Affiliation == pubsubmodel.Member, nil +} diff --git a/module/xep0163/access_checker_test.go b/module/xep0163/access_checker_test.go new file mode 100644 index 000000000..601546cea --- /dev/null +++ b/module/xep0163/access_checker_test.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0163 + +import ( + "context" + "testing" + + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + rostermodel "github.com/ortuman/jackal/model/roster" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/stretchr/testify/require" +) + +func TestAccessChecker_Open(t *testing.T) { + ac := &accessChecker{ + host: "ortuman@jackal.im", + nodeID: "princely_musings", + accessModel: pubsubmodel.Open, + rosterRep: memorystorage.NewRoster(), + } + + err := ac.checkAccess(context.Background(), "noelia@jackal.im") + require.Nil(t, err) +} + +func TestAccessChecker_Outcast(t *testing.T) { + ac := &accessChecker{ + host: "ortuman@jackal.im", + nodeID: "princely_musings", + accessModel: pubsubmodel.Open, + affiliation: &pubsubmodel.Affiliation{JID: "noelia@jackal.im", Affiliation: pubsubmodel.Outcast}, + rosterRep: memorystorage.NewRoster(), + } + + err := ac.checkAccess(context.Background(), "noelia@jackal.im") + require.NotNil(t, err) + require.Equal(t, errOutcastMember, err) +} + +func TestAccessChecker_PresenceSubscription(t *testing.T) { + rosterRep := memorystorage.NewRoster() + ac := &accessChecker{ + host: "ortuman@jackal.im", + nodeID: "princely_musings", + accessModel: pubsubmodel.Presence, + rosterRep: rosterRep, + } + + err := ac.checkAccess(context.Background(), "noelia@jackal.im") + require.NotNil(t, err) + require.Equal(t, errPresenceSubscriptionRequired, err) + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: rostermodel.SubscriptionFrom, + }) + + err = ac.checkAccess(context.Background(), "noelia@jackal.im") + require.Nil(t, err) +} + +func TestAccessChecker_RosterGroup(t *testing.T) { + rosterRep := memorystorage.NewRoster() + ac := &accessChecker{ + host: "ortuman@jackal.im", + nodeID: "princely_musings", + rosterAllowedGroups: []string{"Family"}, + accessModel: pubsubmodel.Roster, + rosterRep: rosterRep, + } + + err := ac.checkAccess(context.Background(), "noelia@jackal.im") + require.NotNil(t, err) + require.Equal(t, errNotInRosterGroup, err) + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Groups: []string{"Family"}, + Subscription: rostermodel.SubscriptionFrom, + }) + + err = ac.checkAccess(context.Background(), "noelia@jackal.im") + require.Nil(t, err) +} + +func TestAccessChecker_Member(t *testing.T) { + ac := &accessChecker{ + host: "ortuman@jackal.im", + nodeID: "princely_musings", + accessModel: pubsubmodel.WhiteList, + affiliation: &pubsubmodel.Affiliation{JID: "noelia@jackal.im", Affiliation: pubsubmodel.Member}, + rosterRep: memorystorage.NewRoster(), + } + + err := ac.checkAccess(context.Background(), "noelia2@jackal.im") + require.NotNil(t, err) + require.Equal(t, errNotOnWhiteList, err) + + err = ac.checkAccess(context.Background(), "noelia@jackal.im") + require.Nil(t, err) +} diff --git a/module/xep0163/disco_provider.go b/module/xep0163/disco_provider.go new file mode 100644 index 000000000..e94c88b78 --- /dev/null +++ b/module/xep0163/disco_provider.go @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0163 + +import ( + "context" + + "github.com/ortuman/jackal/log" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +var pepFeatures = []string{ + "http://jabber.org/protocol/pubsub#access-presence", + "http://jabber.org/protocol/pubsub#auto-create", + "http://jabber.org/protocol/pubsub#auto-subscribe", + "http://jabber.org/protocol/pubsub#config-node", + "http://jabber.org/protocol/pubsub#create-and-configure", + "http://jabber.org/protocol/pubsub#create-nodes", + "http://jabber.org/protocol/pubsub#filtered-notifications", + "http://jabber.org/protocol/pubsub#persistent-items", + "http://jabber.org/protocol/pubsub#publish", + "http://jabber.org/protocol/pubsub#retrieve-items", + "http://jabber.org/protocol/pubsub#subscribe", +} + +type discoInfoProvider struct { + rosterRep repository.Roster + pubSubRep repository.PubSub +} + +func (p *discoInfoProvider) Identities(_ context.Context, _, _ *jid.JID, node string) []xep0030.Identity { + var identities []xep0030.Identity + if len(node) > 0 { + identities = append(identities, xep0030.Identity{Type: "leaf", Category: "pubsub"}) + } else { + identities = append(identities, xep0030.Identity{Type: "collection", Category: "pubsub"}) + } + identities = append(identities, xep0030.Identity{Type: "pep", Category: "pubsub"}) + return identities +} + +func (p *discoInfoProvider) Features(_ context.Context, _, _ *jid.JID, _ string) ([]xep0030.Feature, *xmpp.StanzaError) { + return pepFeatures, nil +} + +func (p *discoInfoProvider) Form(_ context.Context, _, _ *jid.JID, _ string) (*xep0004.DataForm, *xmpp.StanzaError) { + return nil, nil +} + +func (p *discoInfoProvider) Items(ctx context.Context, toJID, fromJID *jid.JID, node string) ([]xep0030.Item, *xmpp.StanzaError) { + if !p.isSubscribedTo(ctx, toJID, fromJID) { + return nil, xmpp.ErrSubscriptionRequired + } + host := toJID.ToBareJID().String() + + if len(node) > 0 { + // return node items + return p.nodeItems(ctx, host, node) + } + // return host nodes + return p.hostNodes(ctx, host) +} + +func (p *discoInfoProvider) hostNodes(ctx context.Context, host string) ([]xep0030.Item, *xmpp.StanzaError) { + var items []xep0030.Item + + nodes, err := p.pubSubRep.FetchNodes(ctx, host) + if err != nil { + log.Error(err) + return nil, xmpp.ErrInternalServerError + } + for _, node := range nodes { + items = append(items, xep0030.Item{ + Jid: host, + Node: node.Name, + Name: node.Options.Title, + }) + } + return items, nil +} + +func (p *discoInfoProvider) nodeItems(ctx context.Context, host, node string) ([]xep0030.Item, *xmpp.StanzaError) { + var items []xep0030.Item + + n, err := p.pubSubRep.FetchNode(ctx, host, node) + if err != nil { + log.Error(err) + return nil, xmpp.ErrInternalServerError + } + if n == nil { + // does not exist + return nil, xmpp.ErrItemNotFound + } + nodeItems, err := p.pubSubRep.FetchNodeItems(ctx, host, node) + if err != nil { + log.Error(err) + return nil, xmpp.ErrInternalServerError + } + for _, nodeItem := range nodeItems { + items = append(items, xep0030.Item{ + Jid: nodeItem.Publisher, + Name: nodeItem.ID, + }) + } + return items, nil +} + +func (p *discoInfoProvider) isSubscribedTo(ctx context.Context, contact *jid.JID, userJID *jid.JID) bool { + if contact.MatchesWithOptions(userJID, jid.MatchesBare) { + return true + } + ri, err := p.rosterRep.FetchRosterItem(ctx, userJID.Node(), contact.ToBareJID().String()) + if err != nil { + log.Error(err) + return false + } + if ri == nil { + return false + } + return ri.Subscription == rostermodel.SubscriptionTo || ri.Subscription == rostermodel.SubscriptionBoth +} diff --git a/module/xep0163/disco_provider_test.go b/module/xep0163/disco_provider_test.go new file mode 100644 index 000000000..f1c6955e3 --- /dev/null +++ b/module/xep0163/disco_provider_test.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0163 + +import ( + "context" + "reflect" + "testing" + + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + rostermodel "github.com/ortuman/jackal/model/roster" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestDiscoInfoProvider_Identities(t *testing.T) { + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "yard", true) + + dp := &discoInfoProvider{} + + ids := dp.Identities(context.Background(), j1, j2, "") + require.Len(t, ids, 2) + + require.Equal(t, "collection", ids[0].Type) + require.Equal(t, "pubsub", ids[0].Category) + require.Equal(t, "pep", ids[1].Type) + require.Equal(t, "pubsub", ids[1].Category) + + ids = dp.Identities(context.Background(), j1, j2, "node") + require.Len(t, ids, 2) + + require.Equal(t, "leaf", ids[0].Type) + require.Equal(t, "pubsub", ids[0].Category) + require.Equal(t, "pep", ids[1].Type) + require.Equal(t, "pubsub", ids[1].Category) +} + +func TestDiscoInfoProvider_Items(t *testing.T) { + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "yard", true) + + pubSubRep := memorystorage.NewPubSub() + + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + rosterRep := memorystorage.NewRoster() + dp := &discoInfoProvider{ + rosterRep: rosterRep, + pubSubRep: pubSubRep, + } + + items, err := dp.Items(context.Background(), j1, j2, "") + require.Nil(t, items) + require.NotNil(t, err) + require.Equal(t, xmpp.ErrSubscriptionRequired, err) + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "noelia", + JID: "ortuman@jackal.im", + Subscription: rostermodel.SubscriptionTo, + }) + + items, err = dp.Items(context.Background(), j1, j2, "") + require.Nil(t, err) + require.Len(t, items, 1) + + require.Equal(t, "ortuman@jackal.im", items[0].Jid) + require.Equal(t, "princely_musings", items[0].Node) +} + +func TestDiscoInfoProvider_Features(t *testing.T) { + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "yard", true) + + dp := &discoInfoProvider{} + + features, _ := dp.Features(context.Background(), j1, j2, "") + require.True(t, reflect.DeepEqual(features, pepFeatures)) + + features, _ = dp.Features(context.Background(), j1, j2, "node") + require.True(t, reflect.DeepEqual(features, pepFeatures)) +} + +func TestDiscoInfoProvider_Form(t *testing.T) { + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "yard", true) + + dp := &discoInfoProvider{} + + features, _ := dp.Features(context.Background(), j1, j2, "") + require.True(t, reflect.DeepEqual(features, pepFeatures)) + + form, _ := dp.Form(context.Background(), j1, j2, "") + require.Nil(t, form) + + form, _ = dp.Form(context.Background(), j1, j2, "node") + require.Nil(t, form) +} diff --git a/module/xep0163/pep.go b/module/xep0163/pep.go new file mode 100755 index 000000000..f3cbaef51 --- /dev/null +++ b/module/xep0163/pep.go @@ -0,0 +1,1256 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0163 + +import ( + "context" + "crypto/sha256" + "fmt" + + "github.com/google/uuid" + "github.com/ortuman/jackal/log" + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/module/xep0115" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/util/runqueue" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +const ( + pubSubNamespace = "http://jabber.org/protocol/pubsub" + pubSubOwnerNamespace = "http://jabber.org/protocol/pubsub#owner" + pubSubEventNamespace = "http://jabber.org/protocol/pubsub#event" + + pubSubErrorNamespace = "http://jabber.org/protocol/pubsub#errors" +) + +var defaultNodeOptions = pubsubmodel.Options{ + DeliverNotifications: true, + DeliverPayloads: true, + PersistItems: true, + AccessModel: pubsubmodel.Presence, + MaxItems: 1, + SendLastPublishedItem: pubsubmodel.OnSubAndPresence, + NotificationType: xmpp.HeadlineType, +} + +type commandOptions struct { + allowedAffiliations []string + includeAffiliations bool + includeSubscriptions bool + checkAccess bool + failOnNotFound bool +} + +type commandContext struct { + host string + nodeID string + isAccountOwner bool + node *pubsubmodel.Node + affiliations []pubsubmodel.Affiliation + subscriptions []pubsubmodel.Subscription + accessChecker *accessChecker +} + +// Pep represents a Personal Eventing Protocol module. +type Pep struct { + runQueue *runqueue.RunQueue + router router.Router + rosterRep repository.Roster + pubSubRep repository.PubSub + disco *xep0030.DiscoInfo + entityCaps *xep0115.EntityCaps + hosts []string +} + +// New returns a PEP command IQ handler module. +func New(disco *xep0030.DiscoInfo, presenceHub *xep0115.EntityCaps, router router.Router, rosterRep repository.Roster, pubSubRep repository.PubSub) *Pep { + p := &Pep{ + runQueue: runqueue.New("xep0163"), + rosterRep: rosterRep, + pubSubRep: pubSubRep, + router: router, + disco: disco, + entityCaps: presenceHub, + } + // register account identity and features + if disco != nil { + for _, feature := range pepFeatures { + disco.RegisterAccountFeature(feature) + } + } + // register disco items + p.registerDiscoItems(context.Background()) + return p +} + +// MatchesIQ returns whether or not an IQ should be processed by the PEP module. +func (x *Pep) MatchesIQ(iq *xmpp.IQ) bool { + pubSub := iq.Elements().Child("pubsub") + if pubSub == nil { + return false + } + switch pubSub.Namespace() { + case pubSubNamespace, pubSubOwnerNamespace: + return true + } + return false +} + +// ProcessIQ processes a version IQ taking according actions over the associated stream +func (x *Pep) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { + x.runQueue.Run(func() { + x.processIQ(ctx, iq) + }) +} + +// SubscribeToAll subscribes a jid to all host nodes +func (x *Pep) SubscribeToAll(ctx context.Context, host string, jid *jid.JID) { + x.runQueue.Run(func() { + if err := x.subscribeToAll(ctx, host, jid); err != nil { + log.Error(err) + } + }) +} + +// UnsubscribeFromAll unsubscribes a jid from all host nodes +func (x *Pep) UnsubscribeFromAll(ctx context.Context, host string, jid *jid.JID) { + x.runQueue.Run(func() { + if err := x.unsubscribeFromAll(ctx, host, jid); err != nil { + log.Error(err) + } + }) +} + +// DeliverLastItems delivers last items from all those nodes to which the jid is subscribed +func (x *Pep) DeliverLastItems(ctx context.Context, jid *jid.JID) { + x.runQueue.Run(func() { + if err := x.deliverLastItems(ctx, jid); err != nil { + log.Error(err) + } + }) +} + +// Shutdown shuts down version module. +func (x *Pep) Shutdown() error { + c := make(chan struct{}) + x.runQueue.Stop(func() { close(c) }) + <-c + return nil +} + +func (x *Pep) processIQ(ctx context.Context, iq *xmpp.IQ) { + pubSub := iq.Elements().Child("pubsub") + switch pubSub.Namespace() { + case pubSubNamespace: + x.processRequest(ctx, iq, pubSub) + case pubSubOwnerNamespace: + x.processOwnerRequest(ctx, iq, pubSub) + } +} + +func (x *Pep) registerDiscoItems(ctx context.Context) { + if x.disco == nil { + return // nothing to do here + } + if err := x.registerDiscoItemHandlers(ctx); err != nil { + log.Warnf("pep: failed to register disco item handlers: %v", err) + } +} + +func (x *Pep) registerDiscoItemHandlers(ctx context.Context) error { + // unregister previous handlers + for _, h := range x.hosts { + x.disco.UnregisterProvider(h) + } + // register current ones + hosts, err := x.pubSubRep.FetchHosts(ctx) + if err != nil { + return err + } + for _, host := range hosts { + x.disco.RegisterProvider(host, &discoInfoProvider{ + rosterRep: x.rosterRep, + pubSubRep: x.pubSubRep, + }) + } + x.hosts = hosts + return nil +} + +func (x *Pep) subscribeToAll(ctx context.Context, host string, subJID *jid.JID) error { + nodes, err := x.pubSubRep.FetchNodes(ctx, host) + if err != nil { + return err + } + for _, node := range nodes { + if err := x.subscribeTo(ctx, &node, subJID); err != nil { + return err + } + } + return nil +} + +func (x *Pep) subscribeTo(ctx context.Context, n *pubsubmodel.Node, subJID *jid.JID) error { + // upsert subscription + subID := subscriptionID(subJID.ToBareJID().String(), n.Host, n.Name) + sub := pubsubmodel.Subscription{ + SubID: subID, + JID: subJID.ToBareJID().String(), + Subscription: pubsubmodel.Subscribed, + } + if err := x.pubSubRep.UpsertNodeSubscription(ctx, &sub, n.Host, n.Name); err != nil { + return err + } + log.Infof("pep: subscription created (host: %s, node_id: %s, jid: %s)", n.Host, n.Name, subJID) + + // notify subscription update + affiliations, err := x.pubSubRep.FetchNodeAffiliations(ctx, n.Host, n.Name) + if err != nil { + return err + } + subscriptionElem := xmpp.NewElementName("subscription") + subscriptionElem.SetAttribute("node", n.Name) + subscriptionElem.SetAttribute("jid", subJID.ToBareJID().String()) + subscriptionElem.SetAttribute("subid", subID) + subscriptionElem.SetAttribute("subscription", pubsubmodel.Subscribed) + + if n.Options.DeliverNotifications && n.Options.NotifySub { + x.notifyOwners(ctx, subscriptionElem, affiliations, n.Host, n.Options.NotificationType) + } + // send last node item + switch n.Options.SendLastPublishedItem { + case pubsubmodel.OnSub, pubsubmodel.OnSubAndPresence: + var subAff *pubsubmodel.Affiliation + for _, aff := range affiliations { + if aff.JID == subJID.ToBareJID().String() { + subAff = &aff + break + } + } + accessChecker := &accessChecker{ + host: n.Host, + nodeID: n.Name, + accessModel: n.Options.AccessModel, + rosterAllowedGroups: n.Options.RosterGroupsAllowed, + affiliation: subAff, + rosterRep: x.rosterRep, + } + return x.sendLastPublishedItem(ctx, subJID, accessChecker, n.Host, n.Name, n.Options.NotificationType) + } + return nil +} + +func (x *Pep) unsubscribeFromAll(ctx context.Context, host string, subJID *jid.JID) error { + nodes, err := x.pubSubRep.FetchNodes(ctx, host) + if err != nil { + return err + } + for _, n := range nodes { + if err := x.pubSubRep.DeleteNodeSubscription(ctx, subJID.ToBareJID().String(), host, n.Name); err != nil { + return err + } + log.Infof("pep: subscription removed (host: %s, node_id: %s, jid: %s)", host, n.Name, subJID.ToBareJID().String()) + } + return nil +} + +func (x *Pep) deliverLastItems(ctx context.Context, jid *jid.JID) error { + nodes, err := x.pubSubRep.FetchSubscribedNodes(ctx, jid.ToBareJID().String()) + if err != nil { + return err + } + for _, node := range nodes { + if node.Options.SendLastPublishedItem != pubsubmodel.OnSubAndPresence { + continue + } + aff, err := x.pubSubRep.FetchNodeAffiliation(ctx, node.Host, node.Name, jid.ToBareJID().String()) + if err != nil { + return err + } + accessChecker := &accessChecker{ + host: node.Host, + nodeID: node.Name, + accessModel: node.Options.AccessModel, + rosterAllowedGroups: node.Options.RosterGroupsAllowed, + affiliation: aff, + rosterRep: x.rosterRep, + } + if err := x.sendLastPublishedItem(ctx, jid, accessChecker, node.Host, node.Name, node.Options.NotificationType); err != nil { + return err + } + log.Infof("pep: delivered last item: %s (node: %s, host: %s)", jid.String(), node.Host, node.Name) + } + return nil +} + +func (x *Pep) processRequest(ctx context.Context, iq *xmpp.IQ, pubSubEl xmpp.XElement) { + // Create node + if cmdEl := pubSubEl.Elements().Child("create"); cmdEl != nil && iq.IsSet() { + x.withCommandContext(ctx, commandOptions{}, cmdEl, iq, func(cmdCtx *commandContext) { + x.create(ctx, cmdCtx, pubSubEl, iq) + }) + return + } + // Publish + if cmdEl := pubSubEl.Elements().Child("publish"); cmdEl != nil && iq.IsSet() { + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner, pubsubmodel.Member}, + includeSubscriptions: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.publish(ctx, cmdCtx, cmdEl, iq) + }) + return + } + // Subscribe + if cmdEl := pubSubEl.Elements().Child("subscribe"); cmdEl != nil && iq.IsSet() { + opts := commandOptions{ + includeAffiliations: true, + checkAccess: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.subscribe(ctx, cmdCtx, cmdEl, iq) + }) + return + } + // Unsubscribe + if cmdEl := pubSubEl.Elements().Child("unsubscribe"); cmdEl != nil && iq.IsSet() { + opts := commandOptions{ + includeAffiliations: true, + includeSubscriptions: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.unsubscribe(ctx, cmdCtx, cmdEl, iq) + }) + return + } + // Retrieve items + if cmdEl := pubSubEl.Elements().Child("items"); cmdEl != nil && iq.IsGet() { + opts := commandOptions{ + includeSubscriptions: true, + checkAccess: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.retrieveItems(ctx, cmdCtx, cmdEl, iq) + }) + return + } + + _ = x.router.Route(ctx, iq.ServiceUnavailableError()) +} + +func (x *Pep) processOwnerRequest(ctx context.Context, iq *xmpp.IQ, pubSub xmpp.XElement) { + // Configure node + if cmdEl := pubSub.Elements().Child("configure"); cmdEl != nil { + if iq.IsGet() { + // send configuration form + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.sendConfigurationForm(ctx, cmdCtx, iq) + }) + } else if iq.IsSet() { + // update node configuration + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + includeSubscriptions: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.configure(ctx, cmdCtx, cmdEl, iq) + }) + } else { + _ = x.router.Route(ctx, iq.ServiceUnavailableError()) + } + return + } + // Manage affiliations + if cmdEl := pubSub.Elements().Child("affiliations"); cmdEl != nil { + if iq.IsGet() { + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + includeAffiliations: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.retrieveAffiliations(ctx, cmdCtx, iq) + }) + } else if iq.IsSet() { + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.updateAffiliations(ctx, cmdCtx, cmdEl, iq) + }) + } else { + _ = x.router.Route(ctx, iq.ServiceUnavailableError()) + } + return + } + // Manage subscriptions + if cmdEl := pubSub.Elements().Child("subscriptions"); cmdEl != nil { + if iq.IsGet() { + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + includeSubscriptions: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.retrieveSubscriptions(ctx, cmdCtx, iq) + }) + } else if iq.IsSet() { + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.updateSubscriptions(ctx, cmdCtx, cmdEl, iq) + }) + } else { + _ = x.router.Route(ctx, iq.ServiceUnavailableError()) + } + return + } + // Delete node + if cmdEl := pubSub.Elements().Child("delete"); cmdEl != nil && iq.IsSet() { + opts := commandOptions{ + allowedAffiliations: []string{pubsubmodel.Owner}, + includeSubscriptions: true, + failOnNotFound: true, + } + x.withCommandContext(ctx, opts, cmdEl, iq, func(cmdCtx *commandContext) { + x.delete(ctx, cmdCtx, iq) + }) + return + } + + _ = x.router.Route(ctx, iq.FeatureNotImplementedError()) +} + +func (x *Pep) create(ctx context.Context, cmdCtx *commandContext, pubSubEl xmpp.XElement, iq *xmpp.IQ) { + if cmdCtx.node != nil { + _ = x.router.Route(ctx, iq.ConflictError()) + return + } + node := &pubsubmodel.Node{ + Host: cmdCtx.host, + Name: cmdCtx.nodeID, + } + if configEl := pubSubEl.Elements().Child("configure"); configEl != nil { + form, err := xep0004.NewFormFromElement(configEl) + if err != nil { + _ = x.router.Route(ctx, iq.BadRequestError()) + return + } + opts, err := pubsubmodel.NewOptionsFromSubmitForm(form) + if err != nil { + _ = x.router.Route(ctx, iq.BadRequestError()) + return + } + node.Options = *opts + } else { + // apply default configuration + node.Options = defaultNodeOptions + } + if err := x.createNode(ctx, node); err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + log.Infof("pep: created node (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + _ = x.router.Route(ctx, iq.ResultIQ()) +} + +func (x *Pep) sendConfigurationForm(ctx context.Context, cmdCtx *commandContext, iq *xmpp.IQ) { + // compose config form response + configureNode := xmpp.NewElementName("configure") + configureNode.SetAttribute("node", cmdCtx.nodeID) + + rosterGroups, err := x.rosterRep.FetchRosterGroups(ctx, iq.ToJID().Node()) + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + + configureNode.AppendElement(cmdCtx.node.Options.Form(rosterGroups).Element()) + + pubSubNode := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + pubSubNode.AppendElement(configureNode) + + res := iq.ResultIQ() + res.AppendElement(pubSubNode) + + log.Infof("pep: sent configuration form (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + _ = x.router.Route(ctx, res) +} + +func (x *Pep) configure(ctx context.Context, cmdCtx *commandContext, cmdElem xmpp.XElement, iq *xmpp.IQ) { + formEl := cmdElem.Elements().ChildNamespace("x", xep0004.FormNamespace) + if formEl == nil { + _ = x.router.Route(ctx, iq.NotAcceptableError()) + return + } + configForm, err := xep0004.NewFormFromElement(formEl) + if err != nil { + _ = x.router.Route(ctx, iq.NotAcceptableError()) + return + } + nodeOpts, err := pubsubmodel.NewOptionsFromSubmitForm(configForm) + if err != nil { + _ = x.router.Route(ctx, iq.NotAcceptableError()) + return + } + cmdCtx.node.Options = *nodeOpts + + // update node config + if err := x.pubSubRep.UpsertNode(ctx, cmdCtx.node); err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + // notify config update + opts := cmdCtx.node.Options + + if opts.DeliverNotifications && opts.NotifyConfig { + configElem := xmpp.NewElementName("configuration") + configElem.SetAttribute("node", cmdCtx.nodeID) + + if opts.DeliverPayloads { + configElem.AppendElement(opts.ResultForm().Element()) + } + x.notifySubscribers( + ctx, + configElem, + cmdCtx.subscriptions, + cmdCtx.accessChecker, + cmdCtx.host, + cmdCtx.nodeID, + opts.NotificationType) + } + log.Infof("pep: node configuration updated (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + _ = x.router.Route(ctx, iq.ResultIQ()) +} + +func (x *Pep) delete(ctx context.Context, cmdCtx *commandContext, iq *xmpp.IQ) { + // delete node + if err := x.pubSubRep.DeleteNode(ctx, cmdCtx.host, cmdCtx.nodeID); err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + // notify delete + opts := cmdCtx.node.Options + + if opts.DeliverNotifications && opts.NotifyDelete { + deleteElem := xmpp.NewElementName("delete") + deleteElem.SetAttribute("node", cmdCtx.nodeID) + + x.notifySubscribers( + ctx, + deleteElem, + cmdCtx.subscriptions, + cmdCtx.accessChecker, + cmdCtx.host, + cmdCtx.nodeID, + opts.NotificationType) + } + log.Infof("pep: deleted node (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + x.registerDiscoItems(ctx) + _ = x.router.Route(ctx, iq.ResultIQ()) +} + +func (x *Pep) subscribe(ctx context.Context, cmdCtx *commandContext, cmdEl xmpp.XElement, iq *xmpp.IQ) { + // validate JID portion + subJID := cmdEl.Attributes().Get("jid") + if subJID != iq.FromJID().ToBareJID().String() { + _ = x.router.Route(ctx, invalidJIDError(iq)) + return + } + // create subscription + subID := subscriptionID(subJID, cmdCtx.host, cmdCtx.nodeID) + + sub := pubsubmodel.Subscription{ + SubID: subID, + JID: subJID, + Subscription: pubsubmodel.Subscribed, + } + err := x.pubSubRep.UpsertNodeSubscription(ctx, &sub, cmdCtx.host, cmdCtx.nodeID) + + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + log.Infof("pep: subscription created (host: %s, node_id: %s, jid: %s)", cmdCtx.host, cmdCtx.nodeID, subJID) + + // notify subscription update + subscriptionElem := xmpp.NewElementName("subscription") + subscriptionElem.SetAttribute("node", cmdCtx.nodeID) + subscriptionElem.SetAttribute("jid", subJID) + subscriptionElem.SetAttribute("subid", subID) + subscriptionElem.SetAttribute("subscription", pubsubmodel.Subscribed) + + opts := cmdCtx.node.Options + if opts.DeliverNotifications && opts.NotifySub { + x.notifyOwners(ctx, subscriptionElem, cmdCtx.affiliations, cmdCtx.host, opts.NotificationType) + } + // send last node item + switch opts.SendLastPublishedItem { + case pubsubmodel.OnSub, pubsubmodel.OnSubAndPresence: + subscriberJID, _ := jid.NewWithString(sub.JID, true) + err := x.sendLastPublishedItem(ctx, subscriberJID, cmdCtx.accessChecker, cmdCtx.host, cmdCtx.nodeID, opts.NotificationType) + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + } + + // compose response + iqRes := iq.ResultIQ() + pubSubElem := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + pubSubElem.AppendElement(subscriptionElem) + iqRes.AppendElement(pubSubElem) + + _ = x.router.Route(ctx, iqRes) +} + +func (x *Pep) unsubscribe(ctx context.Context, cmdCtx *commandContext, cmdEl xmpp.XElement, iq *xmpp.IQ) { + subJID := cmdEl.Attributes().Get("jid") + if subJID != iq.FromJID().ToBareJID().String() { + _ = x.router.Route(ctx, iq.ForbiddenError()) + return + } + var subscription *pubsubmodel.Subscription + for _, sub := range cmdCtx.subscriptions { + if sub.JID == subJID { + subscription = &sub + break + } + } + if subscription == nil { + _ = x.router.Route(ctx, notSubscribedError(iq)) + return + } + // delete subscription + if err := x.pubSubRep.DeleteNodeSubscription(ctx, subJID, cmdCtx.host, cmdCtx.nodeID); err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + log.Infof("pep: subscription removed (host: %s, node_id: %s, jid: %s)", cmdCtx.host, cmdCtx.nodeID, subJID) + + // notify subscription update + subscriptionElem := xmpp.NewElementName("subscription") + subscriptionElem.SetAttribute("node", cmdCtx.nodeID) + subscriptionElem.SetAttribute("jid", subJID) + subscriptionElem.SetAttribute("subid", subscription.SubID) + subscriptionElem.SetAttribute("subscription", pubsubmodel.None) + + opts := cmdCtx.node.Options + if opts.DeliverNotifications && opts.NotifySub { + x.notifyOwners(ctx, subscriptionElem, cmdCtx.affiliations, cmdCtx.host, opts.NotificationType) + } + + // compose response + iqRes := iq.ResultIQ() + pubSubElem := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + pubSubElem.AppendElement(subscriptionElem) + iqRes.AppendElement(pubSubElem) + + _ = x.router.Route(ctx, iqRes) +} + +func (x *Pep) publish(ctx context.Context, cmdCtx *commandContext, cmdEl xmpp.XElement, iq *xmpp.IQ) { + itemEl := cmdEl.Elements().Child("item") + if itemEl == nil || len(itemEl.Elements().All()) != 1 { + _ = x.router.Route(ctx, invalidPayloadError(iq)) + return + } + itemID := itemEl.Attributes().Get("id") + if len(itemID) == 0 { + // generate unique item identifier + itemID = uuid.New().String() + } + // auto create node + if cmdCtx.node == nil { + if !cmdCtx.isAccountOwner { + _ = x.router.Route(ctx, iq.ForbiddenError()) + return + } + cmdCtx.node = &pubsubmodel.Node{ + Host: cmdCtx.host, + Name: cmdCtx.nodeID, + Options: defaultNodeOptions, + } + cmdCtx.subscriptions = []pubsubmodel.Subscription{{ + JID: cmdCtx.host, + SubID: subscriptionID(cmdCtx.host, cmdCtx.host, cmdCtx.nodeID), + Subscription: pubsubmodel.Subscribed, + }} + if err := x.createNode(ctx, cmdCtx.node); err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + } + // persist node item + opts := cmdCtx.node.Options + if opts.PersistItems { + err := x.pubSubRep.UpsertNodeItem(ctx, &pubsubmodel.Item{ + ID: itemID, + Publisher: iq.FromJID().ToBareJID().String(), + Payload: itemEl.Elements().All()[0], + }, cmdCtx.host, cmdCtx.nodeID, int(opts.MaxItems)) + + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + } + log.Infof("pep: published item (host: %s, node_id: %s, item_id: %s)", cmdCtx.host, cmdCtx.nodeID, itemID) + + // notify published item + itemsElem := xmpp.NewElementName("items") + itemsElem.SetAttribute("node", cmdCtx.nodeID) + + itemElem := xmpp.NewElementName("item") + itemElem.SetAttribute("id", itemID) + if opts.DeliverPayloads || !opts.PersistItems { + itemElem.AppendElement(itemEl.Elements().All()[0]) + } + itemsElem.AppendElement(itemElem) + + x.notifySubscribers( + ctx, + itemsElem, + cmdCtx.subscriptions, + cmdCtx.accessChecker, + cmdCtx.host, + cmdCtx.nodeID, + cmdCtx.node.Options.NotificationType) + + // compose response + publishElem := xmpp.NewElementName("publish") + publishElem.SetAttribute("node", cmdCtx.nodeID) + resItemElem := xmpp.NewElementName("item") + resItemElem.SetAttribute("id", itemID) + publishElem.AppendElement(resItemElem) + + iqRes := iq.ResultIQ() + pubSubElem := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + pubSubElem.AppendElement(publishElem) + iqRes.AppendElement(pubSubElem) + + _ = x.router.Route(ctx, iqRes) +} + +func (x *Pep) retrieveItems(ctx context.Context, cmdCtx *commandContext, cmdEl xmpp.XElement, iq *xmpp.IQ) { + var itemIDs []string + + itemElems := cmdEl.Elements().Children("item") + if len(itemElems) > 0 { + for _, itemEl := range itemElems { + itemID := itemEl.Attributes().Get("id") + if len(itemID) == 0 { + continue + } + itemIDs = append(itemIDs, itemID) + } + } + // retrieve node items + var items []pubsubmodel.Item + var err error + + if len(itemIDs) > 0 { + items, err = x.pubSubRep.FetchNodeItemsWithIDs(ctx, cmdCtx.host, cmdCtx.nodeID, itemIDs) + } else { + items, err = x.pubSubRep.FetchNodeItems(ctx, cmdCtx.host, cmdCtx.nodeID) + } + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + log.Infof("pep: retrieved items (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + // compose response + iqRes := iq.ResultIQ() + pubSubElem := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + itemsElem := xmpp.NewElementName("items") + itemsElem.SetAttribute("node", cmdCtx.nodeID) + + for _, itm := range items { + itemElem := xmpp.NewElementName("item") + itemElem.SetAttribute("id", itm.ID) + itemElem.AppendElement(itm.Payload) + + itemsElem.AppendElement(itemElem) + } + pubSubElem.AppendElement(itemsElem) + iqRes.AppendElement(pubSubElem) + + _ = x.router.Route(ctx, iqRes) +} + +func (x *Pep) retrieveAffiliations(ctx context.Context, cmdCtx *commandContext, iq *xmpp.IQ) { + affiliationsElem := xmpp.NewElementName("affiliations") + affiliationsElem.SetAttribute("node", cmdCtx.nodeID) + + for _, aff := range cmdCtx.affiliations { + affElem := xmpp.NewElementName("affiliation") + affElem.SetAttribute("jid", aff.JID) + affElem.SetAttribute("affiliation", aff.Affiliation) + + affiliationsElem.AppendElement(affElem) + } + log.Infof("pep: retrieved affiliations (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + // compose response + iqRes := iq.ResultIQ() + pubSubElem := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + pubSubElem.AppendElement(affiliationsElem) + iqRes.AppendElement(pubSubElem) + + _ = x.router.Route(ctx, iqRes) +} + +func (x *Pep) updateAffiliations(ctx context.Context, cmdCtx *commandContext, cmdElem xmpp.XElement, iq *xmpp.IQ) { + // update affiliations + for _, affElem := range cmdElem.Elements().Children("affiliation") { + var aff pubsubmodel.Affiliation + aff.JID = affElem.Attributes().Get("jid") + aff.Affiliation = affElem.Attributes().Get("affiliation") + + if aff.JID == cmdCtx.host { + // ignore node owner affiliation update + continue + } + var err error + switch aff.Affiliation { + case pubsubmodel.Owner, pubsubmodel.Member, pubsubmodel.Publisher, pubsubmodel.Outcast: + err = x.pubSubRep.UpsertNodeAffiliation(ctx, &aff, cmdCtx.host, cmdCtx.nodeID) + case pubsubmodel.None: + err = x.pubSubRep.DeleteNodeAffiliation(ctx, aff.JID, cmdCtx.host, cmdCtx.nodeID) + default: + _ = x.router.Route(ctx, iq.BadRequestError()) + return + } + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + } + log.Infof("pep: modified affiliations (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + _ = x.router.Route(ctx, iq.ResultIQ()) +} + +func (x *Pep) retrieveSubscriptions(ctx context.Context, cmdCtx *commandContext, iq *xmpp.IQ) { + subscriptionsElem := xmpp.NewElementName("subscriptions") + subscriptionsElem.SetAttribute("node", cmdCtx.nodeID) + + for _, sub := range cmdCtx.subscriptions { + subElem := xmpp.NewElementName("subscription") + subElem.SetAttribute("subid", sub.SubID) + subElem.SetAttribute("jid", sub.JID) + subElem.SetAttribute("subscription", sub.Subscription) + + subscriptionsElem.AppendElement(subElem) + } + log.Infof("pep: retrieved subscriptions (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + // compose response + iqRes := iq.ResultIQ() + pubSubElem := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + pubSubElem.AppendElement(subscriptionsElem) + iqRes.AppendElement(pubSubElem) + + _ = x.router.Route(ctx, iqRes) +} + +func (x *Pep) updateSubscriptions(ctx context.Context, cmdCtx *commandContext, cmdElem xmpp.XElement, iq *xmpp.IQ) { + // update subscriptions + for _, subElem := range cmdElem.Elements().Children("subscription") { + var sub pubsubmodel.Subscription + sub.SubID = subElem.Attributes().Get("subid") + sub.JID = subElem.Attributes().Get("jid") + sub.Subscription = subElem.Attributes().Get("subscription") + + if sub.JID == cmdCtx.host { + // ignore node owner subscription update + continue + } + var err error + switch sub.Subscription { + case pubsubmodel.Subscribed: + err = x.pubSubRep.UpsertNodeSubscription(ctx, &sub, cmdCtx.host, cmdCtx.nodeID) + case pubsubmodel.None: + err = x.pubSubRep.DeleteNodeSubscription(ctx, sub.JID, cmdCtx.host, cmdCtx.nodeID) + default: + _ = x.router.Route(ctx, iq.BadRequestError()) + return + } + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + } + log.Infof("pep: modified subscriptions (host: %s, node_id: %s)", cmdCtx.host, cmdCtx.nodeID) + + _ = x.router.Route(ctx, iq.ResultIQ()) +} + +func (x *Pep) notifyOwners(ctx context.Context, notificationElem xmpp.XElement, affiliations []pubsubmodel.Affiliation, host, notificationType string) { + hostJID, _ := jid.NewWithString(host, true) + for _, affiliation := range affiliations { + if affiliation.Affiliation != pubsubmodel.Owner { + continue + } + toJID, _ := jid.NewWithString(affiliation.JID, true) + eventMsg := eventMessage(notificationElem, hostJID, toJID, notificationType) + + _ = x.router.Route(ctx, eventMsg) + } +} + +func (x *Pep) notifySubscribers( + ctx context.Context, + notificationElem xmpp.XElement, + subscribers []pubsubmodel.Subscription, + accessChecker *accessChecker, + host string, + nodeID string, + notificationType string, +) { + var toJIDs []jid.JID + for _, subscriber := range subscribers { + if subscriber.Subscription != pubsubmodel.Subscribed { + continue + } + subscriberJID, _ := jid.NewWithString(subscriber.JID, true) + toJIDs = append(toJIDs, *subscriberJID) + } + x.notify(ctx, notificationElem, toJIDs, accessChecker, host, nodeID, notificationType) +} + +func (x *Pep) notify( + ctx context.Context, + notificationElem xmpp.XElement, + toJIDs []jid.JID, + accessChecker *accessChecker, + host string, + nodeID string, + notificationType string, +) { + hostJID, _ := jid.NewWithString(host, true) + for _, toJID := range toJIDs { + if toJID.ToBareJID().String() != host { + // check JID access before notifying + err := accessChecker.checkAccess(ctx, toJID.ToBareJID().String()) + switch err { + case nil: + break + case errPresenceSubscriptionRequired, errNotInRosterGroup, errNotOnWhiteList: + continue + default: + log.Error(err) + continue + } + } + + if ph := x.entityCaps; ph != nil { + onlinePresences, err := ph.PresencesMatchingJID(ctx, &toJID) + if err != nil { + log.Error(err) + } + + for _, onlinePresence := range onlinePresences { + caps := onlinePresence.Caps + if caps == nil { + goto broadcastEventMsg // broadcast when caps are pending to be fetched + } + if !caps.HasFeature(nodeID + "+notify") { + continue + } + // notify to full jid + presence := onlinePresence.Presence + + eventMsg := eventMessage(notificationElem, hostJID, presence.FromJID(), notificationType) + _ = x.router.Route(ctx, eventMsg) + } + return + } + broadcastEventMsg: + // broadcast event message + eventMsg := eventMessage(notificationElem, hostJID, &toJID, notificationType) + _ = x.router.Route(ctx, eventMsg) + } +} + +func (x *Pep) withCommandContext(ctx context.Context, opts commandOptions, cmdElem xmpp.XElement, iq *xmpp.IQ, fn func(cmdCtx *commandContext)) { + var cmdCtx commandContext + + nodeID := cmdElem.Attributes().Get("node") + if len(nodeID) == 0 { + _ = x.router.Route(ctx, nodeIDRequiredError(iq)) + return + } + fromJID := iq.FromJID().ToBareJID().String() + host := iq.ToJID().ToBareJID().String() + + cmdCtx.host = host + cmdCtx.nodeID = nodeID + cmdCtx.isAccountOwner = fromJID == host + + // fetch node + node, err := x.pubSubRep.FetchNode(ctx, host, nodeID) + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + if node == nil { + if opts.failOnNotFound { + _ = x.router.Route(ctx, iq.ItemNotFoundError()) + } else { + fn(&cmdCtx) + } + return + } + cmdCtx.node = node + + // fetch affiliation + aff, err := x.pubSubRep.FetchNodeAffiliation(ctx, host, nodeID, fromJID) + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + cmdCtx.accessChecker = &accessChecker{ + host: node.Host, + nodeID: node.Name, + accessModel: node.Options.AccessModel, + rosterAllowedGroups: node.Options.RosterGroupsAllowed, + affiliation: aff, + rosterRep: x.rosterRep, + } + // check access + if opts.checkAccess && !cmdCtx.isAccountOwner { + err := cmdCtx.accessChecker.checkAccess(ctx, fromJID) + switch err { + case nil: + break + + case errOutcastMember: + _ = x.router.Route(ctx, iq.ForbiddenError()) + return + + case errPresenceSubscriptionRequired: + _ = x.router.Route(ctx, presenceSubscriptionRequiredError(iq)) + return + + case errNotInRosterGroup: + _ = x.router.Route(ctx, notInRosterGroupError(iq)) + return + + case errNotOnWhiteList: + _ = x.router.Route(ctx, notOnWhitelistError(iq)) + return + + default: + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + } + // validate affiliation + if len(opts.allowedAffiliations) > 0 { + var allowed bool + for _, allowedAff := range opts.allowedAffiliations { + if aff != nil && aff.Affiliation == allowedAff { + allowed = true + break + } + } + if !allowed { + _ = x.router.Route(ctx, iq.ForbiddenError()) + return + } + } + // fetch subscriptions + if opts.includeSubscriptions { + subscriptions, err := x.pubSubRep.FetchNodeSubscriptions(ctx, host, nodeID) + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + cmdCtx.subscriptions = subscriptions + } + // fetch affiliations + if opts.includeAffiliations { + affiliations, err := x.pubSubRep.FetchNodeAffiliations(ctx, host, nodeID) + if err != nil { + log.Error(err) + _ = x.router.Route(ctx, iq.InternalServerError()) + return + } + cmdCtx.affiliations = affiliations + } + fn(&cmdCtx) +} + +func (x *Pep) createNode(ctx context.Context, node *pubsubmodel.Node) error { + // create node + if err := x.pubSubRep.UpsertNode(ctx, node); err != nil { + return err + } + // create owner affiliation + ownerAffiliation := &pubsubmodel.Affiliation{ + JID: node.Host, + Affiliation: pubsubmodel.Owner, + } + if err := x.pubSubRep.UpsertNodeAffiliation(ctx, ownerAffiliation, node.Host, node.Name); err != nil { + return err + } + // create owner subscription + ownerSub := &pubsubmodel.Subscription{ + SubID: subscriptionID(node.Host, node.Host, node.Name), + JID: node.Host, + Subscription: pubsubmodel.Subscribed, + } + if err := x.pubSubRep.UpsertNodeSubscription(ctx, ownerSub, node.Host, node.Name); err != nil { + return err + } + // auto-subscribe roster members + j, err := jid.NewWithString(node.Host, true) + if err != nil { + return err + } + rosterItems, _, err := x.rosterRep.FetchRosterItems(ctx, j.Node()) + if err != nil { + return err + } + for _, ri := range rosterItems { + if ri.Subscription != rostermodel.SubscriptionBoth && ri.Subscription != rostermodel.SubscriptionFrom { + continue + } + subJID, _ := jid.NewWithString(ri.JID, true) + if err := x.subscribeTo(ctx, node, subJID); err != nil { + return err + } + } + x.registerDiscoItems(ctx) + return nil +} + +func (x *Pep) sendLastPublishedItem(ctx context.Context, toJID *jid.JID, accessChecker *accessChecker, host, nodeID, notificationType string) error { + node, err := x.pubSubRep.FetchNode(ctx, host, nodeID) + if err != nil { + return err + } + if node == nil { + return nil + } + lastItem, err := x.pubSubRep.FetchNodeLastItem(ctx, host, nodeID) + if err != nil { + return err + } + if lastItem == nil { + return nil + } + itemsEl := xmpp.NewElementName("items") + itemsEl.SetAttribute("node", nodeID) + itemEl := xmpp.NewElementName("item") + itemEl.SetAttribute("id", lastItem.ID) + if node.Options.DeliverPayloads || !node.Options.PersistItems { + itemEl.AppendElement(lastItem.Payload) + } + itemsEl.AppendElement(itemEl) + + x.notify( + ctx, + itemsEl, + []jid.JID{*toJID}, + accessChecker, + host, + nodeID, + notificationType) + return nil +} + +func eventMessage(payloadElem xmpp.XElement, hostJID, toJID *jid.JID, notificationType string) *xmpp.Message { + msg := xmpp.NewMessageType(uuid.New().String(), notificationType) + msg.SetFromJID(hostJID) + msg.SetToJID(toJID) + eventElem := xmpp.NewElementNamespace("event", pubSubEventNamespace) + eventElem.AppendElement(payloadElem) + msg.AppendElement(eventElem) + + return msg +} + +func nodeIDRequiredError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("nodeid-required", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrNotAcceptable, errorElements) +} + +func invalidPayloadError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("invalid-payload", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrBadRequest, errorElements) +} + +func invalidJIDError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("invalid-jid", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrBadRequest, errorElements) +} + +func presenceSubscriptionRequiredError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("presence-subscription-required", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrNotAuthorized, errorElements) +} + +func notInRosterGroupError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("not-in-roster-group", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrNotAuthorized, errorElements) +} + +func notOnWhitelistError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("closed-node", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrNotAllowed, errorElements) +} + +func notSubscribedError(stanza xmpp.Stanza) xmpp.Stanza { + errorElements := []xmpp.XElement{xmpp.NewElementNamespace("not-subscribed", pubSubErrorNamespace)} + return xmpp.NewErrorStanzaFromStanza(stanza, xmpp.ErrUnexpectedRequest, errorElements) +} + +func subscriptionID(jid, host, name string) string { + h := sha256.New() + h.Write([]byte(jid + host + name)) + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/module/xep0163/pep_test.go b/module/xep0163/pep_test.go new file mode 100755 index 000000000..c1121b508 --- /dev/null +++ b/module/xep0163/pep_test.go @@ -0,0 +1,958 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package xep0163 + +import ( + "context" + "crypto/tls" + "testing" + + c2srouter "github.com/ortuman/jackal/c2s/router" + capsmodel "github.com/ortuman/jackal/model/capabilities" + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/module/xep0004" + "github.com/ortuman/jackal/module/xep0115" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/router/host" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" + "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/pborman/uuid" + "github.com/stretchr/testify/require" +) + +func TestXEP0163_Matching(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm := stream.NewMockC2S(uuid.New(), j) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) + + p := New(nil, nil, r, rosterRep, pubSubRep) + + // test MatchesIQ + iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) + iq.SetFromJID(j) + iq.SetToJID(j.ToBareJID()) + iq.AppendElement(xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace)) + + require.True(t, p.MatchesIQ(iq)) +} + +func TestXEP163_CreateNode(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm := stream.NewMockC2S(uuid.New(), j) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) + + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j) + iq.SetToJID(j.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + create := xmpp.NewElementName("create") + create.SetAttribute("node", "princely_musings") + pubSub.AppendElement(create) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, iqID, elem.ID()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + // read node + n, _ := pubSubRep.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + require.NotNil(t, n) + require.Equal(t, n.Options, defaultNodeOptions) +} + +func TestXEP163_GetNodeConfiguration(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm := stream.NewMockC2S(uuid.New(), j) + stm.SetPresence(xmpp.NewPresence(j, j, xmpp.AvailableType)) + + r.Bind(context.Background(), stm) + + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.GetType) + iq.SetFromJID(j) + iq.SetToJID(j.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + configureElem := xmpp.NewElementName("configure") + configureElem.SetAttribute("node", "princely_musings") + pubSub.AppendElement(configureElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, iqID, elem.ID()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + // get form element + pubSubRes := elem.Elements().ChildNamespace("pubsub", pubSubOwnerNamespace) + require.NotNil(t, pubSubRes) + configElem := pubSubRes.Elements().Child("configure") + require.NotNil(t, configElem) + formEl := configElem.Elements().ChildNamespace("x", xep0004.FormNamespace) + require.NotNil(t, formEl) + + configForm, err := xep0004.NewFormFromElement(formEl) + require.Nil(t, err) + require.Equal(t, xep0004.Form, configForm.Type) +} + +func TestXEP163_SetNodeConfiguration(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + + nodeOpts := defaultNodeOptions + nodeOpts.NotifyConfig = true + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: nodeOpts, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + JID: "ortuman@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + JID: "noelia@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: "both", + }) + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + configureElem := xmpp.NewElementName("configure") + configureElem.SetAttribute("node", "princely_musings") + + // attach config update + nodeOpts.Title = "a fancy new title" + + configForm := nodeOpts.ResultForm() + configForm.Type = xep0004.Submit + configureElem.AppendElement(configForm.Element()) + + pubSub.AppendElement(configureElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + + elem := stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "message", elem.Name()) // notification + require.NotNil(t, elem.Elements().ChildNamespace("event", pubSubEventNamespace)) + + elem2 := stm2.ReceiveElement() + require.NotNil(t, elem2) + require.Equal(t, "message", elem.Name()) // notification + eventElem := elem2.Elements().ChildNamespace("event", pubSubEventNamespace) + require.NotNil(t, eventElem) + + configElemResp := eventElem.Elements().Child("configuration") + require.NotNil(t, configElemResp) + require.Equal(t, "princely_musings", configElemResp.Attributes().Get("node")) + + // result IQ + elem = stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, iqID, elem.ID()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + // check if configuration was applied + n, _ := pubSubRep.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + require.NotNil(t, n) + require.Equal(t, nodeOpts.Title, n.Options.Title) +} + +func TestXEP163_DeleteNode(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + + nodeOpts := defaultNodeOptions + nodeOpts.NotifyDelete = true + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: nodeOpts, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + JID: "ortuman@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + JID: "noelia@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: "both", + }) + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + deleteElem := xmpp.NewElementName("delete") + deleteElem.SetAttribute("node", "princely_musings") + pubSub.AppendElement(deleteElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "message", elem.Name()) // notification + require.NotNil(t, elem.Elements().ChildNamespace("event", pubSubEventNamespace)) + + elem2 := stm2.ReceiveElement() + require.NotNil(t, elem2) + require.Equal(t, "message", elem.Name()) // notification + eventElem := elem2.Elements().ChildNamespace("event", pubSubEventNamespace) + require.NotNil(t, eventElem) + + deleteElemResp := eventElem.Elements().Child("delete") + require.NotNil(t, deleteElemResp) + require.Equal(t, "princely_musings", deleteElemResp.Attributes().Get("node")) + + // result IQ + elem = stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, iqID, elem.ID()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + // read node + n, _ := pubSubRep.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, n) +} + +func TestXEP163_UpdateAffiliations(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + + // create node + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + // create new affiliation + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + affElem := xmpp.NewElementName("affiliations") + affElem.SetAttribute("node", "princely_musings") + + affiliation := xmpp.NewElementName("affiliation") + affiliation.SetAttribute("jid", "noelia@jackal.im") + affiliation.SetAttribute("affiliation", pubsubmodel.Owner) + affElem.AppendElement(affiliation) + pubSub.AppendElement(affElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + aff, _ := pubSubRep.FetchNodeAffiliation(context.Background(), "ortuman@jackal.im", "princely_musings", "noelia@jackal.im") + require.NotNil(t, aff) + require.Equal(t, "noelia@jackal.im", aff.JID) + require.Equal(t, pubsubmodel.Owner, aff.Affiliation) + + // remove affiliation + affiliation.SetAttribute("affiliation", pubsubmodel.None) + + p.ProcessIQ(context.Background(), iq) + elem = stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + aff, _ = pubSubRep.FetchNodeAffiliation(context.Background(), "ortuman@jackal.im", "princely_musings", "noelia@jackal.im") + require.Nil(t, aff) +} + +func TestXEP163_RetrieveAffiliations(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "noelia@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.GetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + affElem := xmpp.NewElementName("affiliations") + affElem.SetAttribute("node", "princely_musings") + pubSub.AppendElement(affElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + + pubSubElem := elem.Elements().ChildNamespace("pubsub", pubSubOwnerNamespace) + require.NotNil(t, pubSubElem) + + affiliationsElem := pubSubElem.Elements().Child("affiliations") + require.NotNil(t, affiliationsElem) + + affiliations := affiliationsElem.Elements().Children("affiliation") + require.Len(t, affiliations, 2) + + require.Equal(t, "ortuman@jackal.im", affiliations[0].Attributes().Get("jid")) + require.Equal(t, pubsubmodel.Owner, affiliations[0].Attributes().Get("affiliation")) + require.Equal(t, "noelia@jackal.im", affiliations[1].Attributes().Get("jid")) + require.Equal(t, pubsubmodel.Owner, affiliations[1].Attributes().Get("affiliation")) +} + +func TestXEP163_UpdateSubscriptions(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + + // create node + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + // create new subscription + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + subElem := xmpp.NewElementName("subscriptions") + subElem.SetAttribute("node", "princely_musings") + + sub := xmpp.NewElementName("subscription") + sub.SetAttribute("jid", "noelia@jackal.im") + sub.SetAttribute("subscription", pubsubmodel.Subscribed) + subElem.AppendElement(sub) + pubSub.AppendElement(subElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + subs, _ := pubSubRep.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.NotNil(t, subs) + require.Len(t, subs, 1) + require.Equal(t, "noelia@jackal.im", subs[0].JID) + require.Equal(t, pubsubmodel.Subscribed, subs[0].Subscription) + + // remove subscription + sub.SetAttribute("subscription", pubsubmodel.None) + + p.ProcessIQ(context.Background(), iq) + elem = stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + subs, _ = pubSubRep.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, subs) +} + +func TestXEP163_RetrieveSubscriptions(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + SubID: uuid.New(), + JID: "noelia@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.GetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubOwnerNamespace) + affElem := xmpp.NewElementName("subscriptions") + affElem.SetAttribute("node", "princely_musings") + pubSub.AppendElement(affElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + + pubSubElem := elem.Elements().ChildNamespace("pubsub", pubSubOwnerNamespace) + require.NotNil(t, pubSubElem) + + subscriptionsElem := pubSubElem.Elements().Child("subscriptions") + require.NotNil(t, subscriptionsElem) + + subscriptions := subscriptionsElem.Elements().Children("subscription") + require.Len(t, subscriptions, 1) + + require.Equal(t, "noelia@jackal.im", subscriptions[0].Attributes().Get("jid")) + require.Equal(t, pubsubmodel.Subscribed, subscriptions[0].Attributes().Get("subscription")) +} + +func TestXEP163_Subscribe(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + + // create node and affiliations + nodeOpts := defaultNodeOptions + nodeOpts.NotifySub = true + + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: nodeOpts, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: "both", + }) + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j2) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + subElem := xmpp.NewElementName("subscribe") + subElem.SetAttribute("node", "princely_musings") + subElem.SetAttribute("jid", "noelia@jackal.im") + pubSub.AppendElement(subElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm2.ReceiveElement() + + // command reply + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + pubSubElem := elem.Elements().ChildNamespace("pubsub", pubSubNamespace) + require.NotNil(t, pubSubElem) + subscriptionElem := pubSubElem.Elements().Child("subscription") + require.NotNil(t, subscriptionElem) + require.Equal(t, "noelia@jackal.im", subscriptionElem.Attributes().Get("jid")) + require.Equal(t, "subscribed", subscriptionElem.Attributes().Get("subscription")) + require.Equal(t, "princely_musings", subscriptionElem.Attributes().Get("node")) + + // subscription notification + elem = stm1.ReceiveElement() + require.NotNil(t, elem) + require.Equal(t, "message", elem.Name()) + + eventElem := elem.Elements().ChildNamespace("event", pubSubEventNamespace) + require.NotNil(t, eventElem) + + subscriptionElem = eventElem.Elements().Child("subscription") + require.NotNil(t, subscriptionElem) + require.Equal(t, "noelia@jackal.im", subscriptionElem.Attributes().Get("jid")) + require.Equal(t, "subscribed", subscriptionElem.Attributes().Get("subscription")) + require.Equal(t, "princely_musings", subscriptionElem.Attributes().Get("node")) + + // check storage subscription + subs, _ := pubSubRep.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Len(t, subs, 1) + require.Equal(t, "noelia@jackal.im", subs[0].JID) + require.Equal(t, pubsubmodel.Subscribed, subs[0].Subscription) +} + +func TestXEP163_Unsubscribe(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "balcony", true) + + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + + r.Bind(context.Background(), stm2) + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: "both", + }) + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + SubID: uuid.New(), + JID: "noelia@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + // process pubsub command + p := New(nil, nil, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j2) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + subElem := xmpp.NewElementName("unsubscribe") + subElem.SetAttribute("node", "princely_musings") + subElem.SetAttribute("jid", "noelia@jackal.im") + pubSub.AppendElement(subElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm2.ReceiveElement() + + // command reply + require.NotNil(t, elem) + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + // check storage subscription + subs, _ := pubSubRep.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Len(t, subs, 0) +} + +func TestXEP163_RetrieveItems(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "balcony", true) + + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: "both", + }) + + // create items + _ = pubSubRep.UpsertNodeItem(context.Background(), &pubsubmodel.Item{ + ID: "i1", + Publisher: "noelia@jackal.im", + Payload: xmpp.NewElementName("m1"), + }, "ortuman@jackal.im", "princely_musings", 2) + + _ = pubSubRep.UpsertNodeItem(context.Background(), &pubsubmodel.Item{ + ID: "i2", + Publisher: "noelia@jackal.im", + Payload: xmpp.NewElementName("m2"), + }, "ortuman@jackal.im", "princely_musings", 2) + + p := New(nil, nil, r, rosterRep, pubSubRep) + + // retrieve all items + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.GetType) + iq.SetFromJID(j2) + iq.SetToJID(j1.ToBareJID()) + + pubSub := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + itemsCmdElem := xmpp.NewElementName("items") + itemsCmdElem.SetAttribute("node", "princely_musings") + pubSub.AppendElement(itemsCmdElem) + iq.AppendElement(pubSub) + + p.ProcessIQ(context.Background(), iq) + elem := stm2.ReceiveElement() + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + pubSubElem := elem.Elements().ChildNamespace("pubsub", pubSubNamespace) + require.NotNil(t, pubSubElem) + itemsElem := pubSubElem.Elements().Child("items") + require.NotNil(t, itemsElem) + items := itemsElem.Elements().Children("item") + require.Len(t, items, 2) + + require.Equal(t, "i1", items[0].Attributes().Get("id")) + require.Equal(t, "i2", items[1].Attributes().Get("id")) + + // retrieve item i2 + i2Elem := xmpp.NewElementName("item") + i2Elem.SetAttribute("id", "i2") + itemsCmdElem.AppendElement(i2Elem) + + p.ProcessIQ(context.Background(), iq) + elem = stm2.ReceiveElement() + require.Equal(t, "iq", elem.Name()) + require.Equal(t, xmpp.ResultType, elem.Type()) + + pubSubElem = elem.Elements().ChildNamespace("pubsub", pubSubNamespace) + require.NotNil(t, pubSubElem) + itemsElem = pubSubElem.Elements().Child("items") + require.NotNil(t, itemsElem) + items = itemsElem.Elements().Children("item") + require.Len(t, items, 1) + + require.Equal(t, "i2", items[0].Attributes().Get("id")) +} + +func TestXEP163_SubscribeToAll(t *testing.T) { + r, _, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + + // create node and affiliations + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "noelia@jackal.im", + Name: "princely_musings_1", + Options: defaultNodeOptions, + }) + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "noelia@jackal.im", + Name: "princely_musings_2", + Options: defaultNodeOptions, + }) + _ = pubSubRep.UpsertNodeItem(context.Background(), &pubsubmodel.Item{ + ID: "i2", + Publisher: "noelia@jackal.im", + Payload: xmpp.NewElementName("m2"), + }, "noelia@jackal.im", "princely_musings_2", 2) + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "noelia", + JID: "ortuman@jackal.im", + Subscription: "both", + }) + p := New(nil, nil, r, rosterRep, pubSubRep) + + err := p.subscribeToAll(context.Background(), "noelia@jackal.im", j1) + require.Nil(t, err) + + nodes, _ := pubSubRep.FetchSubscribedNodes(context.Background(), j1.ToBareJID().String()) + require.NotNil(t, nodes) + require.Len(t, nodes, 2) + + err = p.unsubscribeFromAll(context.Background(), "noelia@jackal.im", j1) + require.Nil(t, err) + + nodes, _ = pubSubRep.FetchSubscribedNodes(context.Background(), j1.ToBareJID().String()) + require.Nil(t, nodes) +} + +func TestXEP163_FilteredNotifications(t *testing.T) { + r, presencesRep, rosterRep, pubSubRep := setupTest("jackal.im") + + j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) + j2, _ := jid.New("noelia", "jackal.im", "balcony", true) + stm1 := stream.NewMockC2S(uuid.New(), j1) + stm2 := stream.NewMockC2S(uuid.New(), j2) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + + // create node, affiliations and subscriptions + _ = pubSubRep.UpsertNode(context.Background(), &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + Options: defaultNodeOptions, + }) + + _ = pubSubRep.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: pubsubmodel.Owner, + }, "ortuman@jackal.im", "princely_musings") + + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ + Username: "ortuman", + JID: "noelia@jackal.im", + Subscription: "both", + }) + + _ = pubSubRep.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + SubID: uuid.New(), + JID: "noelia@jackal.im", + Subscription: pubsubmodel.Subscribed, + }, "ortuman@jackal.im", "princely_musings") + + // set capabilities + _ = presencesRep.UpsertCapabilities(context.Background(), &capsmodel.Capabilities{ + Node: "http://code.google.com/p/exodus", + Ver: "QgayPKawpkPSDYmwT/WM94uAlu0=", + Features: []string{"princely_musings+notify"}, + }) + caps := xep0115.New(r, presencesRep, "alloc-1234") + + // register presence + pr2 := xmpp.NewPresence(j2, j2, xmpp.AvailableType) + c := xmpp.NewElementNamespace("c", "http://jabber.org/protocol/caps") + c.SetAttribute("hash", "sha-1") + c.SetAttribute("node", "http://code.google.com/p/exodus") + c.SetAttribute("ver", "QgayPKawpkPSDYmwT/WM94uAlu0=") + pr2.AppendElement(c) + + _, _ = caps.RegisterPresence(context.Background(), pr2) + + // process pubsub command + p := New(nil, caps, r, rosterRep, pubSubRep) + + iqID := uuid.New() + iq := xmpp.NewIQType(iqID, xmpp.SetType) + iq.SetFromJID(j1) + iq.SetToJID(j1.ToBareJID()) + + pubSubEl := xmpp.NewElementNamespace("pubsub", pubSubNamespace) + publishEl := xmpp.NewElementName("publish") + publishEl.SetAttribute("node", "princely_musings") + itemEl := xmpp.NewElementName("item") + itemEl.SetAttribute("id", "bnd81g37d61f49fgn581") + entryEl := xmpp.NewElementNamespace("entry", "http://www.w3.org/2005/Atom") + itemEl.AppendElement(entryEl) + publishEl.AppendElement(itemEl) + pubSubEl.AppendElement(publishEl) + + iq.AppendElement(pubSubEl) + + p.ProcessIQ(context.Background(), iq) + elem := stm2.ReceiveElement() + require.Equal(t, "message", elem.Name()) + require.Equal(t, xmpp.HeadlineType, elem.Type()) + + eventEl := elem.Elements().ChildNamespace("event", pubSubEventNamespace) + require.NotNil(t, eventEl) + + itemsEl := eventEl.Elements().Child("items") + require.NotNil(t, itemsEl) + + require.Equal(t, "bnd81g37d61f49fgn581", itemsEl.Elements().Child("item").Attributes().Get("id")) +} + +func setupTest(domain string) (router.Router, repository.Presences, repository.Roster, repository.PubSub) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + presencesRep := memorystorage.NewPresences() + rosterRep := memorystorage.NewRoster() + pubSubRep := memorystorage.NewPubSub() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r, presencesRep, rosterRep, pubSubRep +} diff --git a/module/xep0191/block_list.go b/module/xep0191/block_list.go index 02f954b91..7c51410fa 100644 --- a/module/xep0191/block_list.go +++ b/module/xep0191/block_list.go @@ -6,15 +6,17 @@ package xep0191 import ( + "context" + "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/rostermodel" - "github.com/ortuman/jackal/module/roster" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/module/xep0115" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" @@ -26,19 +28,23 @@ const ( xep191RequestedContextKey = "xep_191:requested" ) -// BlockingCommand returns a blocking command IQ handler module. +// BlockingCommand represents a blocking command IQ handler module. type BlockingCommand struct { - router *router.Router - roster *roster.Roster - runQueue *runqueue.RunQueue + runQueue *runqueue.RunQueue + router router.Router + blockListRep repository.BlockList + rosterRep repository.Roster + entityCaps *xep0115.EntityCaps } // New returns a blocking command IQ handler module. -func New(disco *xep0030.DiscoInfo, roster *roster.Roster, router *router.Router) *BlockingCommand { +func New(disco *xep0030.DiscoInfo, entityCaps *xep0115.EntityCaps, router router.Router, rosterRep repository.Roster, blockListRep repository.BlockList) *BlockingCommand { b := &BlockingCommand{ - router: router, - roster: roster, - runQueue: runqueue.New("xep0191"), + runQueue: runqueue.New("xep0191"), + router: router, + blockListRep: blockListRep, + rosterRep: rosterRep, + entityCaps: entityCaps, } if disco != nil { disco.RegisterServerFeature(blockingCommandNamespace) @@ -47,8 +53,7 @@ func New(disco *xep0030.DiscoInfo, roster *roster.Roster, router *router.Router) return b } -// MatchesIQ returns whether or not an IQ should be -// processed by the blocking command module. +// MatchesIQ returns whether or not an IQ should be processed by the blocking command module. func (x *BlockingCommand) MatchesIQ(iq *xmpp.IQ) bool { e := iq.Elements() blockList := e.ChildNamespace("blocklist", blockingCommandNamespace) @@ -57,15 +62,14 @@ func (x *BlockingCommand) MatchesIQ(iq *xmpp.IQ) bool { return (iq.IsGet() && blockList != nil) || (iq.IsSet() && (block != nil || unblock != nil)) } -// ProcessIQ processes a blocking command IQ -// taking according actions over the associated stream. -func (x *BlockingCommand) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a blocking command IQ taking according actions over the associated stream. +func (x *BlockingCommand) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - stm := x.router.UserStream(iq.FromJID()) + stm := x.router.LocalStream(iq.FromJID().Node(), iq.FromJID().Resource()) if stm == nil { return } - x.processIQ(iq, stm) + x.processIQ(ctx, iq, stm) }) } @@ -77,147 +81,160 @@ func (x *BlockingCommand) Shutdown() error { return nil } -func (x *BlockingCommand) processIQ(iq *xmpp.IQ, stm stream.C2S) { +func (x *BlockingCommand) processIQ(ctx context.Context, iq *xmpp.IQ, stm stream.C2S) { if iq.IsGet() { - x.sendBlockList(iq, stm) + x.sendBlockList(ctx, iq, stm) } else if iq.IsSet() { e := iq.Elements() if block := e.ChildNamespace("block", blockingCommandNamespace); block != nil { - x.block(iq, block, stm) + x.block(ctx, iq, block, stm) } else if unblock := e.ChildNamespace("unblock", blockingCommandNamespace); unblock != nil { - x.unblock(iq, unblock, stm) + x.unblock(ctx, iq, unblock, stm) } } } -func (x *BlockingCommand) sendBlockList(iq *xmpp.IQ, stm stream.C2S) { +func (x *BlockingCommand) sendBlockList(ctx context.Context, iq *xmpp.IQ, stm stream.C2S) { fromJID := iq.FromJID() - blItms, err := storage.FetchBlockListItems(fromJID.Node()) + blItems, err := x.blockListRep.FetchBlockListItems(ctx, fromJID.Node()) if err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } blockList := xmpp.NewElementNamespace("blocklist", blockingCommandNamespace) - for _, blItm := range blItms { + for _, blItem := range blItems { itElem := xmpp.NewElementName("item") - itElem.SetAttribute("jid", blItm.JID) + itElem.SetAttribute("jid", blItem.JID) blockList.AppendElement(itElem) } + stm.SetValue(xep191RequestedContextKey, true) + reply := iq.ResultIQ() reply.AppendElement(blockList) - stm.SendElement(reply) - - stm.SetBool(xep191RequestedContextKey, true) + stm.SendElement(ctx, reply) } -func (x *BlockingCommand) block(iq *xmpp.IQ, block xmpp.XElement, stm stream.C2S) { - var bl []model.BlockListItem - +func (x *BlockingCommand) block(ctx context.Context, iq *xmpp.IQ, block xmpp.XElement, stm stream.C2S) { items := block.Elements().Children("item") if len(items) == 0 { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return } jds, err := x.extractItemJIDs(items) if err != nil { log.Error(err) - stm.SendElement(iq.JidMalformedError()) + stm.SendElement(ctx, iq.JidMalformedError()) return } - blItems, ris, err := x.fetchBlockListAndRosterItems(stm) + blItems, ris, err := x.fetchBlockListAndRosterItems(ctx, stm.Username()) if err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } username := stm.Username() for _, j := range jds { - if !x.isJIDInBlockList(j, blItems) { - x.broadcastPresenceMatchingJID(j, ris, xmpp.UnavailableType, stm) - bl = append(bl, model.BlockListItem{Username: username, JID: j.String()}) + if x.isJIDInBlockList(j, blItems) { + continue } + err := x.blockListRep.InsertBlockListItem(ctx, &model.BlockListItem{ + Username: username, + JID: j.String(), + }) + if err != nil { + log.Error(err) + stm.SendElement(ctx, iq.InternalServerError()) + return + } + x.broadcastPresenceMatchingJID(ctx, j, ris, xmpp.UnavailableType, stm) } - if err := storage.InsertBlockListItems(bl); err != nil { - log.Error(err) - stm.SendElement(iq.InternalServerError()) - return - } - x.router.ReloadBlockList(username) - stm.SendElement(iq.ResultIQ()) - x.pushIQ(block, stm) + stm.SendElement(ctx, iq.ResultIQ()) + x.pushIQ(ctx, block, stm) } -func (x *BlockingCommand) unblock(iq *xmpp.IQ, unblock xmpp.XElement, stm stream.C2S) { +func (x *BlockingCommand) unblock(ctx context.Context, iq *xmpp.IQ, unblock xmpp.XElement, stm stream.C2S) { items := unblock.Elements().Children("item") jds, err := x.extractItemJIDs(items) if err != nil { log.Error(err) - stm.SendElement(iq.JidMalformedError()) + stm.SendElement(ctx, iq.JidMalformedError()) return } - blItems, ris, err := x.fetchBlockListAndRosterItems(stm) + username := stm.Username() + + blItems, ris, err := x.fetchBlockListAndRosterItems(ctx, username) if err != nil { log.Error(err) - stm.SendElement(iq.InternalServerError()) + stm.SendElement(ctx, iq.InternalServerError()) return } - username := stm.Username() - var bl []model.BlockListItem - if len(jds) == 0 { + if len(jds) > 0 { + for _, j := range jds { + if !x.isJIDInBlockList(j, blItems) { + continue + } + if err := x.blockListRep.DeleteBlockListItem(ctx, &model.BlockListItem{ + Username: username, + JID: j.String(), + }); err != nil { + log.Error(err) + stm.SendElement(ctx, iq.InternalServerError()) + return + } + x.broadcastPresenceMatchingJID(ctx, j, ris, xmpp.AvailableType, stm) + } + } else { // remove all block list items for _, blItem := range blItems { + if err := x.blockListRep.DeleteBlockListItem(ctx, &blItem); err != nil { + log.Error(err) + stm.SendElement(ctx, iq.InternalServerError()) + return + } j, _ := jid.NewWithString(blItem.JID, true) - x.broadcastPresenceMatchingJID(j, ris, xmpp.AvailableType, stm) - } - bl = blItems - } else { - for _, j := range jds { - if x.isJIDInBlockList(j, blItems) { - x.broadcastPresenceMatchingJID(j, ris, xmpp.AvailableType, stm) - bl = append(bl, model.BlockListItem{Username: username, JID: j.String()}) - } + x.broadcastPresenceMatchingJID(ctx, j, ris, xmpp.AvailableType, stm) } } - if err := storage.DeleteBlockListItems(bl); err != nil { - log.Error(err) - stm.SendElement(iq.InternalServerError()) - return - } - x.router.ReloadBlockList(username) - stm.SendElement(iq.ResultIQ()) - x.pushIQ(unblock, stm) + stm.SendElement(ctx, iq.ResultIQ()) + x.pushIQ(ctx, unblock, stm) } -func (x *BlockingCommand) pushIQ(elem xmpp.XElement, stm stream.C2S) { - stms := x.router.UserStreams(stm.Username()) - for _, stm := range stms { - if !stm.GetBool(xep191RequestedContextKey) { +func (x *BlockingCommand) pushIQ(ctx context.Context, elem xmpp.XElement, stm stream.C2S) { + streams := x.router.LocalStreams(stm.Username()) + for _, stm := range streams { + requested, _ := stm.Value(xep191RequestedContextKey).(bool) + if !requested { continue } iq := xmpp.NewIQType(uuid.New(), xmpp.SetType) iq.AppendElement(elem) - stm.SendElement(iq) + stm.SendElement(ctx, iq) } } -func (x *BlockingCommand) broadcastPresenceMatchingJID(jid *jid.JID, ris []rostermodel.Item, presenceType string, stm stream.C2S) { - if x.roster == nil { +func (x *BlockingCommand) broadcastPresenceMatchingJID(ctx context.Context, blockedJID *jid.JID, ris []rostermodel.Item, presenceType string, stm stream.C2S) { + if x.entityCaps == nil { // roster disabled return } - presences := x.roster.OnlinePresencesMatchingJID(jid) - for _, presence := range presences { + onlinePresences, err := x.entityCaps.PresencesMatchingJID(ctx, blockedJID) + if err != nil { + log.Error(err) + return + } + for _, onlinePresence := range onlinePresences { + presence := onlinePresence.Presence if !x.isSubscribedTo(presence.FromJID().ToBareJID(), ris) { continue } - p := xmpp.NewPresence(presence.FromJID(), stm.JID().ToBareJID(), presenceType) + p := xmpp.NewPresence(stm.JID(), presence.FromJID(), presenceType) if presenceType == xmpp.AvailableType { p.AppendElements(presence.Elements().All()) } - x.router.MustRoute(p) + _ = x.router.MustRoute(ctx, p) } } @@ -231,26 +248,24 @@ func (x *BlockingCommand) isJIDInBlockList(jid *jid.JID, blItems []model.BlockLi } func (x *BlockingCommand) isSubscribedTo(jid *jid.JID, ris []rostermodel.Item) bool { - str := jid.String() for _, ri := range ris { - if ri.JID == str && (ri.Subscription == rostermodel.SubscriptionTo || ri.Subscription == rostermodel.SubscriptionBoth) { - return true + if ri.JID == jid.String() { + return ri.Subscription == rostermodel.SubscriptionFrom || ri.Subscription == rostermodel.SubscriptionBoth } } return false } -func (x *BlockingCommand) fetchBlockListAndRosterItems(stm stream.C2S) ([]model.BlockListItem, []rostermodel.Item, error) { - username := stm.Username() - blItms, err := storage.FetchBlockListItems(username) +func (x *BlockingCommand) fetchBlockListAndRosterItems(ctx context.Context, username string) ([]model.BlockListItem, []rostermodel.Item, error) { + blItems, err := x.blockListRep.FetchBlockListItems(ctx, username) if err != nil { return nil, nil, err } - ris, _, err := storage.FetchRosterItems(username) + ris, _, err := x.rosterRep.FetchRosterItems(ctx, username) if err != nil { return nil, nil, err } - return blItms, ris, nil + return blItems, ris, nil } func (x *BlockingCommand) extractItemJIDs(items []xmpp.XElement) ([]*jid.JID, error) { diff --git a/module/xep0191/block_list_test.go b/module/xep0191/block_list_test.go index d62ef89b5..dd6d35b81 100644 --- a/module/xep0191/block_list_test.go +++ b/module/xep0191/block_list_test.go @@ -6,16 +6,19 @@ package xep0191 import ( + "context" "crypto/tls" "testing" "time" + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/rostermodel" - "github.com/ortuman/jackal/module/roster" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/module/xep0115" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + "github.com/ortuman/jackal/router/host" + memorystorage "github.com/ortuman/jackal/storage/memory" + "github.com/ortuman/jackal/storage/repository" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -24,19 +27,18 @@ import ( ) func TestXEP0191_Matching(t *testing.T) { - rtr, _, shutdown := setupTest("jackal.im") - defer shutdown() + r, presencesRep, blockListRep, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - rtr.Bind(stm) + r.Bind(context.Background(), stm) - r := roster.New(&roster.Config{}, rtr) - defer r.Shutdown() + ph := xep0115.New(r, presencesRep, "alloc-1234") + defer func() { _ = ph.Shutdown() }() - x := New(nil, r, rtr) - defer x.Shutdown() + x := New(nil, ph, r, rosterRep, blockListRep) + defer func() { _ = x.Shutdown() }() // test MatchesIQ iq1 := xmpp.NewIQType(uuid.New(), xmpp.GetType) @@ -59,57 +61,58 @@ func TestXEP0191_Matching(t *testing.T) { } func TestXEP0191_GetBlockList(t *testing.T) { - rtr, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, presencesRep, blockListRep, rosterRep := setupTest("jackal.im") j, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j) - rtr.Bind(stm) + r.Bind(context.Background(), stm) - r := roster.New(&roster.Config{}, rtr) - defer r.Shutdown() + ph := xep0115.New(r, presencesRep, "alloc-1234") + defer func() { _ = ph.Shutdown() }() - x := New(nil, r, rtr) - defer x.Shutdown() + x := New(nil, ph, r, rosterRep, blockListRep) + defer func() { _ = x.Shutdown() }() - storage.InsertBlockListItems([]model.BlockListItem{{ + _ = blockListRep.InsertBlockListItem(context.Background(), &model.BlockListItem{ Username: "ortuman", JID: "hamlet@jackal.im/garden", - }, { + }) + _ = blockListRep.InsertBlockListItem(context.Background(), &model.BlockListItem{ Username: "ortuman", JID: "jabber.org", - }}) + }) iq1 := xmpp.NewIQType(uuid.New(), xmpp.GetType) iq1.SetFromJID(j) iq1.SetToJID(j) iq1.AppendElement(xmpp.NewElementNamespace("blocklist", blockingCommandNamespace)) - x.ProcessIQ(iq1) + x.ProcessIQ(context.Background(), iq1) elem := stm.ReceiveElement() bl := elem.Elements().ChildNamespace("blocklist", blockingCommandNamespace) require.NotNil(t, bl) require.Equal(t, 2, len(bl.Elements().Children("item"))) - require.True(t, stm.GetBool(xep191RequestedContextKey)) + requested, _ := stm.Value(xep191RequestedContextKey).(bool) + require.True(t, requested) - s.EnableMockedError() - x.ProcessIQ(iq1) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq1) elem = stm.ReceiveElement() + require.Len(t, elem.Error().Elements().All(), 1) require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() } func TestXEP191_BlockAndUnblock(t *testing.T) { - rtr, s, shutdown := setupTest("jackal.im") - defer shutdown() + r, presencesRep, blockListRep, rosterRep := setupTest("jackal.im") - r := roster.New(&roster.Config{}, rtr) - defer r.Shutdown() + caps := xep0115.New(r, presencesRep, "alloc-1234") + defer func() { _ = caps.Shutdown() }() - x := New(nil, r, rtr) - defer x.Shutdown() + x := New(nil, caps, r, rosterRep, blockListRep) + defer func() { _ = x.Shutdown() }() j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm1 := stream.NewMockC2S(uuid.New(), j1) @@ -128,23 +131,28 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { stm3.SetAuthenticated(true) stm4.SetAuthenticated(true) - rtr.Bind(stm1) - rtr.Bind(stm2) - rtr.Bind(stm3) - rtr.Bind(stm4) + stm1.SetPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + stm2.SetPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + stm3.SetPresence(xmpp.NewPresence(j3, j3, xmpp.AvailableType)) + stm4.SetPresence(xmpp.NewPresence(j4, j4, xmpp.AvailableType)) + + r.Bind(context.Background(), stm1) + r.Bind(context.Background(), stm2) + r.Bind(context.Background(), stm3) + r.Bind(context.Background(), stm4) // register presences - r.ProcessPresence(xmpp.NewPresence(j1, j1, xmpp.AvailableType)) - r.ProcessPresence(xmpp.NewPresence(j2, j2, xmpp.AvailableType)) - r.ProcessPresence(xmpp.NewPresence(j3, j3, xmpp.AvailableType)) - r.ProcessPresence(xmpp.NewPresence(j4, j4, xmpp.AvailableType)) + _, _ = caps.RegisterPresence(context.Background(), xmpp.NewPresence(j1, j1, xmpp.AvailableType)) + _, _ = caps.RegisterPresence(context.Background(), xmpp.NewPresence(j2, j2, xmpp.AvailableType)) + _, _ = caps.RegisterPresence(context.Background(), xmpp.NewPresence(j3, j3, xmpp.AvailableType)) + _, _ = caps.RegisterPresence(context.Background(), xmpp.NewPresence(j4, j4, xmpp.AvailableType)) time.Sleep(time.Millisecond * 150) // wait until processed... - stm1.SetBool(xep191RequestedContextKey, true) - stm2.SetBool(xep191RequestedContextKey, true) + stm1.SetValue(xep191RequestedContextKey, true) + stm2.SetValue(xep191RequestedContextKey, true) - storage.InsertOrUpdateRosterItem(&rostermodel.Item{ + _, _ = rosterRep.UpsertRosterItem(context.Background(), &rostermodel.Item{ Username: "ortuman", JID: "romeo@jackal.im", Subscription: "both", @@ -157,8 +165,9 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { block := xmpp.NewElementNamespace("block", blockingCommandNamespace) iq.AppendElement(block) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm1.ReceiveElement() + require.Len(t, elem.Error().Elements().All(), 1) require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) item := xmpp.NewElementName("item") @@ -168,19 +177,20 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { iq.AppendElement(block) // TEST BLOCK - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() + require.Len(t, elem.Error().Elements().All(), 1) require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) // unavailable presence from *@jackal.im/jail - elem = stm1.ReceiveElement() + elem = stm4.ReceiveElement() require.Equal(t, "presence", elem.Name()) require.Equal(t, xmpp.UnavailableType, elem.Type()) - require.Equal(t, "romeo@jackal.im/jail", elem.From()) + require.Equal(t, "ortuman@jackal.im/balcony", elem.From()) // result IQ elem = stm1.ReceiveElement() @@ -197,18 +207,12 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { item2 := block.Elements().Child("item") require.NotNil(t, item2) - // ortuman@jackal.im/yard - elem = stm2.ReceiveElement() - require.Equal(t, "presence", elem.Name()) - require.Equal(t, xmpp.UnavailableType, elem.Type()) - require.Equal(t, "romeo@jackal.im/jail", elem.From()) - elem = stm2.ReceiveElement() require.Equal(t, "iq", elem.Name()) require.Equal(t, xmpp.SetType, elem.Type()) // check storage - bl, _ := storage.FetchBlockListItems("ortuman") + bl, _ := blockListRep.FetchBlockListItems(context.Background(), "ortuman") require.NotNil(t, bl) require.Equal(t, 1, len(bl)) require.Equal(t, "jackal.im/jail", bl[0].JID) @@ -224,19 +228,20 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { unblock.AppendElement(item) iq.AppendElement(unblock) - s.EnableMockedError() - x.ProcessIQ(iq) + memorystorage.EnableMockedError() + x.ProcessIQ(context.Background(), iq) elem = stm1.ReceiveElement() + require.Len(t, elem.Error().Elements().All(), 1) require.Equal(t, xmpp.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - s.DisableMockedError() + memorystorage.DisableMockedError() - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) // receive available presence from *@jackal.im/jail - elem = stm1.ReceiveElement() + elem = stm4.ReceiveElement() require.Equal(t, "presence", elem.Name()) require.Equal(t, xmpp.AvailableType, elem.Type()) - require.Equal(t, "romeo@jackal.im/jail", elem.From()) + require.Equal(t, "ortuman@jackal.im/balcony", elem.From()) // result IQ elem = stm1.ReceiveElement() @@ -254,13 +259,14 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { require.NotNil(t, item2) // test full unblock - storage.InsertBlockListItems([]model.BlockListItem{{ + _ = blockListRep.InsertBlockListItem(context.Background(), &model.BlockListItem{ Username: "ortuman", JID: "hamlet@jackal.im/garden", - }, { + }) + _ = blockListRep.InsertBlockListItem(context.Background(), &model.BlockListItem{ Username: "ortuman", JID: "jabber.org", - }}) + }) iqID = uuid.New() iq = xmpp.NewIQType(iqID, xmpp.SetType) @@ -269,21 +275,24 @@ func TestXEP191_BlockAndUnblock(t *testing.T) { unblock = xmpp.NewElementNamespace("unblock", blockingCommandNamespace) iq.AppendElement(unblock) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) time.Sleep(time.Millisecond * 150) // wait until processed... - blItems, _ := storage.FetchBlockListItems("ortuman") + blItems, _ := blockListRep.FetchBlockListItems(context.Background(), "ortuman") require.Equal(t, 0, len(blItems)) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) (router.Router, repository.Presences, repository.BlockList, repository.Roster) { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + + presencesRep := memorystorage.NewPresences() + blockListRep := memorystorage.NewBlockList() + rosterRep := memorystorage.NewRoster() + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), blockListRep), + nil, + ) + return r, presencesRep, blockListRep, rosterRep } diff --git a/module/xep0199/ping.go b/module/xep0199/ping.go index 49686a4ee..9543a2369 100644 --- a/module/xep0199/ping.go +++ b/module/xep0199/ping.go @@ -6,15 +6,17 @@ package xep0199 import ( + "context" "fmt" + "sync" "time" streamerror "github.com/ortuman/jackal/errors" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module/xep0030" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" @@ -22,6 +24,8 @@ import ( const pingNamespace = "urn:xmpp:ping" +const pingWriteTimeout = time.Second + // Config represents XMPP Ping module (XEP-0199) configuration. type Config struct { Send bool @@ -55,15 +59,16 @@ type ping struct { // Ping represents a ping server stream module. type Ping struct { - cfg *Config - router *router.Router - pings map[string]*ping - activePings map[string]*ping - runQueue *runqueue.RunQueue + cfg *Config + router router.Router + pings map[string]*ping + activePingsMu sync.RWMutex + activePings map[string]*ping + runQueue *runqueue.RunQueue } // New returns an ping IQ handler module. -func New(config *Config, disco *xep0030.DiscoInfo, router *router.Router) *Ping { +func New(config *Config, disco *xep0030.DiscoInfo, router router.Router) *Ping { p := &Ping{ cfg: config, router: router, @@ -83,15 +88,14 @@ func (x *Ping) MatchesIQ(iq *xmpp.IQ) bool { return x.isPongIQ(iq) || iq.Elements().ChildNamespace("ping", pingNamespace) != nil } -// ProcessIQ processes a ping IQ taking according actions -// over the associated stream. -func (x *Ping) ProcessIQ(iq *xmpp.IQ) { +// ProcessIQ processes a ping IQ taking according actions over the associated stream. +func (x *Ping) ProcessIQ(ctx context.Context, iq *xmpp.IQ) { x.runQueue.Run(func() { - stm := x.router.UserStream(iq.FromJID()) + stm := x.router.LocalStream(iq.FromJID().Node(), iq.FromJID().Resource()) if stm == nil { return } - x.processIQ(iq, stm) + x.processIQ(ctx, iq, stm) }) } @@ -118,27 +122,27 @@ func (x *Ping) Shutdown() error { return nil } -func (x *Ping) processIQ(iq *xmpp.IQ, stm stream.C2S) { +func (x *Ping) processIQ(ctx context.Context, iq *xmpp.IQ, stm stream.C2S) { if x.isPongIQ(iq) { x.handlePongIQ(iq, stm) return } toJid := iq.ToJID() if !toJid.IsServer() && toJid.Node() != stm.Username() { - stm.SendElement(iq.ForbiddenError()) + stm.SendElement(ctx, iq.ForbiddenError()) return } p := iq.Elements().ChildNamespace("ping", pingNamespace) if p == nil || p.Elements().Count() > 0 { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) return } log.Infof("received ping... id: %s", iq.ID()) if iq.IsGet() { log.Infof("sent pong... id: %s", iq.ID()) - stm.SendElement(iq.ResultIQ()) + stm.SendElement(ctx, iq.ResultIQ()) } else { - stm.SendElement(iq.BadRequestError()) + stm.SendElement(ctx, iq.BadRequestError()) } } @@ -179,7 +183,9 @@ func (x *Ping) schedulePingTimer(stm stream.C2S) { stm: stm, } pi.timer = time.AfterFunc(x.cfg.SendInterval, func() { - x.runQueue.Run(func() { x.sendPing(pi) }) + x.runQueue.Run(func() { + x.sendPing(pi) + }) }) x.pings[stm.JID().String()] = pi } @@ -195,6 +201,8 @@ func (x *Ping) handlePongIQ(iq *xmpp.IQ, stm stream.C2S) { } func (x *Ping) sendPing(pi *ping) { + ctx, _ := context.WithTimeout(context.Background(), pingWriteTimeout) + srvJID, _ := jid.New("", pi.stm.JID().Domain(), "", true) iq := xmpp.NewIQType(pi.identifier, xmpp.GetType) @@ -202,21 +210,29 @@ func (x *Ping) sendPing(pi *ping) { iq.SetToJID(pi.stm.JID()) iq.AppendElement(xmpp.NewElementNamespace("ping", pingNamespace)) - pi.stm.SendElement(iq) + pi.stm.SendElement(ctx, iq) log.Infof("sent ping... id: %s", pi.identifier) pi.timer = time.AfterFunc(x.cfg.SendInterval/3, func() { - x.runQueue.Run(func() { x.disconnectStream(pi) }) + x.runQueue.Run(func() { + x.disconnectStream(pi) + }) }) + x.activePingsMu.Lock() x.activePings[pi.identifier] = pi + x.activePingsMu.Unlock() } func (x *Ping) disconnectStream(pi *ping) { - pi.stm.Disconnect(streamerror.ErrConnectionTimeout) + ctx, _ := context.WithTimeout(context.Background(), pingWriteTimeout) + pi.stm.Disconnect(ctx, streamerror.ErrConnectionTimeout) } func (x *Ping) isPongIQ(iq *xmpp.IQ) bool { + x.activePingsMu.RLock() _, ok := x.activePings[iq.ID()] + x.activePingsMu.RUnlock() + return ok && (iq.IsResult() || iq.Type() == xmpp.ErrorType) } diff --git a/module/xep0199/ping_test.go b/module/xep0199/ping_test.go index b17b2fdb7..ad941ee06 100644 --- a/module/xep0199/ping_test.go +++ b/module/xep0199/ping_test.go @@ -6,11 +6,16 @@ package xep0199 import ( + "context" "crypto/tls" "testing" "time" + "github.com/ortuman/jackal/router/host" + + c2srouter "github.com/ortuman/jackal/c2s/router" "github.com/ortuman/jackal/router" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -22,7 +27,7 @@ func TestXEP0199_Matching(t *testing.T) { j, _ := jid.New("ortuman", "jackal.im", "balcony", true) x := New(&Config{}, nil, nil) - defer x.Shutdown() + defer func() { _ = x.Shutdown() }() // test MatchesIQ iqID := uuid.New() @@ -36,59 +41,55 @@ func TestXEP0199_Matching(t *testing.T) { } func TestXEP0199_ReceivePing(t *testing.T) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: "jackal.im", Certificate: tls.Certificate{}}}, - }) + r := setupTest() j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("juliet", "jackal.im", "garden", true) stm := stream.NewMockC2S(uuid.New(), j1) - r.Bind(stm) + r.Bind(context.Background(), stm) x := New(&Config{}, nil, r) - defer x.Shutdown() + defer func() { _ = x.Shutdown() }() iqID := uuid.New() iq := xmpp.NewIQType(iqID, xmpp.SetType) iq.SetFromJID(j1) iq.SetToJID(j2) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem := stm.ReceiveElement() require.Equal(t, xmpp.ErrForbidden.Error(), elem.Error().Elements().All()[0].Name()) iq.SetToJID(j1) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) ping := xmpp.NewElementNamespace("ping", pingNamespace) iq.AppendElement(ping) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, xmpp.ErrBadRequest.Error(), elem.Error().Elements().All()[0].Name()) iq.SetType(xmpp.GetType) - x.ProcessIQ(iq) + x.ProcessIQ(context.Background(), iq) elem = stm.ReceiveElement() require.Equal(t, iqID, elem.ID()) } func TestXEP0199_SendPing(t *testing.T) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: "jackal.im", Certificate: tls.Certificate{}}}, - }) + r := setupTest() j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) j2, _ := jid.New("", "jackal.im", "", true) stm := stream.NewMockC2S(uuid.New(), j1) - r.Bind(stm) + r.Bind(context.Background(), stm) x := New(&Config{Send: true, SendInterval: time.Second}, nil, r) - defer x.Shutdown() + defer func() { _ = x.Shutdown() }() x.SchedulePing(stm) @@ -102,7 +103,7 @@ func TestXEP0199_SendPing(t *testing.T) { pong := xmpp.NewIQType(elem.ID(), xmpp.ResultType) pong.SetFromJID(j1) pong.SetToJID(j2) - x.ProcessIQ(pong) + x.ProcessIQ(context.Background(), pong) x.SchedulePing(stm) // wait next ping... @@ -117,17 +118,15 @@ func TestXEP0199_SendPing(t *testing.T) { } func TestXEP0199_Disconnect(t *testing.T) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: "jackal.im", Certificate: tls.Certificate{}}}, - }) + r := setupTest() j1, _ := jid.New("ortuman", "jackal.im", "balcony", true) stm := stream.NewMockC2S(uuid.New(), j1) - r.Bind(stm) + r.Bind(context.Background(), stm) x := New(&Config{Send: true, SendInterval: time.Second}, nil, r) - defer x.Shutdown() + defer func() { _ = x.Shutdown() }() x.SchedulePing(stm) @@ -142,3 +141,13 @@ func TestXEP0199_Disconnect(t *testing.T) { require.NotNil(t, err) require.Equal(t, "connection-timeout", err.Error()) } + +func setupTest() router.Router { + hosts, _ := host.New([]host.Config{{Name: "jackal.im", Certificate: tls.Certificate{}}}) + r, _ := router.New( + hosts, + c2srouter.New(memorystorage.NewUser(), memorystorage.NewBlockList()), + nil, + ) + return r +} diff --git a/router/cluster_delegate.go b/router/cluster_delegate.go deleted file mode 100644 index b3c0eb354..000000000 --- a/router/cluster_delegate.go +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package router - -import ( - "github.com/ortuman/jackal/cluster" -) - -type clusterDelegate struct { - r *Router -} - -func (d *clusterDelegate) NotifyMessage(msg *cluster.Message) { d.r.handleNotifyMessage(msg) } -func (d *clusterDelegate) NodeJoined(node *cluster.Node) { d.r.handleNodeJoined(node) } -func (d *clusterDelegate) NodeUpdated(node *cluster.Node) {} -func (d *clusterDelegate) NodeLeft(node *cluster.Node) { d.r.handleNodeLeft(node) } diff --git a/router/config.go b/router/config.go deleted file mode 100644 index 736d64d8a..000000000 --- a/router/config.go +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package router - -import ( - "crypto/tls" - - "github.com/ortuman/jackal/util" - "github.com/pkg/errors" -) - -// Config represents a router configuration. -type Config struct { - Hosts []HostConfig -} - -type configProxy struct { - Hosts []HostConfig `yaml:"hosts"` -} - -// UnmarshalYAML satisfies Unmarshaler interface. -func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { - p := configProxy{} - if err := unmarshal(&p); err != nil { - return err - } - if len(p.Hosts) == 0 { - return errors.New("empty hosts array") - } - c.Hosts = p.Hosts - return nil -} - -type tlsConfig struct { - CertFile string `yaml:"cert_path"` - PrivKeyFile string `yaml:"privkey_path"` -} - -// HostConfig represents a host specific configuration. -type HostConfig struct { - Name string - Certificate tls.Certificate -} - -type hostConfigProxy struct { - Name string `yaml:"name"` - TLS tlsConfig `yaml:"tls"` -} - -// UnmarshalYAML satisfies Unmarshaler interface. -func (c *HostConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - p := hostConfigProxy{} - if err := unmarshal(&p); err != nil { - return err - } - c.Name = p.Name - cer, err := util.LoadCertificate(p.TLS.PrivKeyFile, p.TLS.CertFile, c.Name) - if err != nil { - return err - } - c.Certificate = cer - return nil -} diff --git a/router/config_test.go b/router/config_test.go deleted file mode 100644 index 68ea1099a..000000000 --- a/router/config_test.go +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package router - -import ( - "os" - "testing" - - "github.com/stretchr/testify/require" - yaml "gopkg.in/yaml.v2" -) - -func TestConfig_BadFormat(t *testing.T) { - s := Config{} - - err := yaml.Unmarshal([]byte("{["), &s) - require.NotNil(t, err) - - err = yaml.Unmarshal([]byte("{}"), &s) - require.NotNil(t, err) - - cfg := ` - hosts: - - name: jackal.im - tls: - privkey_path: "key.pem" - cert_path: "cert.pem" -` - err = yaml.Unmarshal([]byte(cfg), &s) - require.NotNil(t, err) -} - -func TestConfig_Valid(t *testing.T) { - defer os.RemoveAll("./.cert") - - s := Config{} - - cfg := ` - hosts: - - name: localhost - tls: - privkey_path: "" - cert_path: "" -` - err := yaml.Unmarshal([]byte(cfg), &s) - require.Nil(t, err) -} diff --git a/router/error.go b/router/error.go index 25bb6240b..9e8d59ddc 100644 --- a/router/error.go +++ b/router/error.go @@ -1,30 +1,25 @@ /* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. * See the LICENSE file for more information. */ package router -import "github.com/pkg/errors" +import "errors" var ( - // ErrNotExistingAccount will be returned by Route method - // if destination user does not exist. + // ErrNotExistingAccount will be returned by Route method if destination user does not exist. ErrNotExistingAccount = errors.New("router: account does not exist") - // ErrResourceNotFound will be returned by Route method - // if destination resource does not match any of user's available resources. + // ErrResourceNotFound will be returned by Route method if destination resource does not match any of user's available resources. ErrResourceNotFound = errors.New("router: resource not found") - // ErrNotAuthenticated will be returned by Route method if - // destination user is not available at this moment. + // ErrNotAuthenticated will be returned by Route method if destination user is not available at this moment. ErrNotAuthenticated = errors.New("router: user not authenticated") - // ErrBlockedJID will be returned by Route method if - // destination jid matches any of the user's blocked jid. + // ErrBlockedJID will be returned by Route method if destination jid matches any of the user's blocked jid. ErrBlockedJID = errors.New("router: destination jid is blocked") - // ErrFailedRemoteConnect will be returned by Route method if - // couldn't establish a connection to the remote server. + // ErrFailedRemoteConnect will be returned by Route method if couldn't establish a connection to the remote server. ErrFailedRemoteConnect = errors.New("router: failed remote connection") ) diff --git a/router/host/config.go b/router/host/config.go new file mode 100644 index 000000000..29218932c --- /dev/null +++ b/router/host/config.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package host + +import ( + "crypto/tls" + + utiltls "github.com/ortuman/jackal/util/tls" +) + +type TLSConfig struct { + CertFile string `yaml:"cert_path"` + PrivateKeyFile string `yaml:"privkey_path"` +} + +type Config struct { + Name string + Certificate tls.Certificate +} + +type configProxy struct { + Name string `yaml:"name"` + TLS TLSConfig `yaml:"tls"` +} + +// UnmarshalYAML satisfies Unmarshaler interface. +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := configProxy{} + if err := unmarshal(&p); err != nil { + return err + } + c.Name = p.Name + cer, err := utiltls.LoadCertificate(p.TLS.PrivateKeyFile, p.TLS.CertFile, c.Name) + if err != nil { + return err + } + c.Certificate = cer + return nil +} diff --git a/router/host/hosts.go b/router/host/hosts.go new file mode 100644 index 000000000..a50122f93 --- /dev/null +++ b/router/host/hosts.go @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package host + +import ( + "crypto/tls" + "sort" + + utiltls "github.com/ortuman/jackal/util/tls" +) + +const defaultDomain = "localhost" + +type Hosts struct { + defaultHostname string + hosts map[string]tls.Certificate + mucHostname string +} + +func New(hostsConfig []Config) (*Hosts, error) { + h := &Hosts{ + hosts: make(map[string]tls.Certificate), + } + if len(hostsConfig) > 0 { + for i, host := range hostsConfig { + if i == 0 { + h.defaultHostname = host.Name + } + h.hosts[host.Name] = host.Certificate + } + } else { + cer, err := utiltls.LoadCertificate("", "", defaultDomain) + if err != nil { + return nil, err + } + h.defaultHostname = defaultDomain + h.hosts[defaultDomain] = cer + } + return h, nil +} + +func (h *Hosts) DefaultHostName() string { + return h.defaultHostname +} + +func (h *Hosts) IsLocalHost(domain string) bool { + _, ok := h.hosts[domain] + return ok +} + +func (h *Hosts) HostNames() []string { + var ret []string + for n := range h.hosts { + ret = append(ret, n) + } + sort.Slice(ret, func(i, j int) bool { return ret[i] < ret[j] }) + return ret +} + +func (h *Hosts) Certificates() []tls.Certificate { + var certs []tls.Certificate + for _, cer := range h.hosts { + certs = append(certs, cer) + } + return certs +} + +func (h *Hosts) AddMucHostname(hostname string) { + h.mucHostname = hostname +} + +func (h *Hosts) IsConferenceHost(domain string) bool { + return domain == h.mucHostname +} diff --git a/router/router.go b/router/router.go index b09efbb40..b753f8f3a 100644 --- a/router/router.go +++ b/router/router.go @@ -1,573 +1,118 @@ /* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * Copyright (c) 2020 Miguel Ɓngel OrtuƱo. * See the LICENSE file for more information. */ package router import ( - "crypto/tls" - "runtime" - "sync" + "context" - "github.com/ortuman/jackal/cluster" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/router/host" "github.com/ortuman/jackal/stream" - "github.com/ortuman/jackal/util" - "github.com/ortuman/jackal/version" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -const defaultDomain = "localhost" +type Router interface { -var bindMsgBatchSize = 1024 + // Hosts returns router hosts container. + Hosts() *host.Hosts -// OutS2SProvider provides a specific s2s outgoing connection for every single -// pair of (localdomain, remotedomain) values. -type OutS2SProvider interface { - GetOut(localDomain, remoteDomain string) (stream.S2SOut, error) -} - -// Cluster represents the generic cluster interface used by router type. -type Cluster interface { - // LocalNode returns local node name. - LocalNode() string - - C2SStream(jid *jid.JID, presence *xmpp.Presence, context map[string]interface{}, node string) *cluster.C2S - - SendMessageTo(node string, message *cluster.Message) - - BroadcastMessage(msg *cluster.Message) -} - -// Router represents an XMPP stanza router. -type Router struct { - mu sync.RWMutex - outS2SProvider OutS2SProvider - hosts map[string]tls.Certificate - streams map[string][]stream.C2S - cluster Cluster - localStreams map[string]stream.C2S - clusterStreams map[string]map[string]*cluster.C2S - - blockListsMu sync.RWMutex - blockLists map[string][]*jid.JID -} - -// New returns an new empty router instance. -func New(config *Config) (*Router, error) { - r := &Router{ - hosts: make(map[string]tls.Certificate), - blockLists: make(map[string][]*jid.JID), - streams: make(map[string][]stream.C2S), - localStreams: make(map[string]stream.C2S), - clusterStreams: make(map[string]map[string]*cluster.C2S), - } - if len(config.Hosts) > 0 { - for _, h := range config.Hosts { - r.hosts[h.Name] = h.Certificate - } - } else { - cer, err := util.LoadCertificate("", "", defaultDomain) - if err != nil { - return nil, err - } - r.hosts[defaultDomain] = cer - } - return r, nil -} - -// HostNames returns the list of all configured host names. -func (r *Router) HostNames() []string { - r.mu.RLock() - defer r.mu.RUnlock() - var ret []string - for n := range r.hosts { - ret = append(ret, n) - } - return ret -} - -// IsLocalHost returns true if domain is a local server domain. -func (r *Router) IsLocalHost(domain string) bool { - r.mu.RLock() - defer r.mu.RUnlock() - _, ok := r.hosts[domain] - return ok -} - -// Certificates returns an array of all configured domain certificates. -func (r *Router) Certificates() []tls.Certificate { - r.mu.RLock() - defer r.mu.RUnlock() - var certs []tls.Certificate - for _, cer := range r.hosts { - certs = append(certs, cer) - } - return certs -} - -// SetOutS2SProvider sets the s2s out provider to be used when routing stanzas remotely. -func (r *Router) SetOutS2SProvider(provider OutS2SProvider) { - r.mu.Lock() - defer r.mu.Unlock() - r.outS2SProvider = provider -} - -// SetCluster sets router cluster interface. -func (r *Router) SetCluster(cluster Cluster) { - r.mu.Lock() - defer r.mu.Unlock() - r.cluster = cluster -} - -// Cluster returns current router cluster. -func (r *Router) Cluster() Cluster { - r.mu.RLock() - defer r.mu.RUnlock() - return r.cluster -} + // Route routes a stanza applying server rules for handling XML stanzas. + // (https://xmpp.org/rfcs/rfc3921.html#rules) + Route(ctx context.Context, stanza xmpp.Stanza) error -// ClusterDelegate returns a router cluster delegate interface. -func (r *Router) ClusterDelegate() cluster.Delegate { - return &clusterDelegate{r: r} -} - -// Bind sets a c2s stream as bound. -// An error will be returned in case no assigned resource is found. -func (r *Router) Bind(stm stream.C2S) { - if len(stm.Resource()) == 0 { - return - } - // bind stream - r.mu.Lock() - defer r.mu.Unlock() - - r.bind(stm) - r.localStreams[stm.JID().String()] = stm - - log.Infof("bound c2s stream... (%s/%s)", stm.Username(), stm.Resource()) - - // broadcast cluster 'bind' message - if r.cluster != nil { - r.cluster.BroadcastMessage(&cluster.Message{ - Type: cluster.MsgBind, - Node: r.cluster.LocalNode(), - Payloads: []cluster.MessagePayload{{ - JID: stm.JID(), - Stanza: stm.Presence(), - Context: stm.Context(), - }}, - }) - } - return -} - -// Unbind unbinds a previously bound c2s stream. -// An error will be returned in case no assigned resource is found. -func (r *Router) Unbind(stmJID *jid.JID) { - if len(stmJID.Resource()) == 0 { - return - } - // unbind stream - r.mu.Lock() - defer r.mu.Unlock() - - if found := r.unbind(stmJID); !found { - return - } - delete(r.localStreams, stmJID.String()) - - log.Infof("unbound c2s stream... (%s/%s)", stmJID.Node(), stmJID.Resource()) - - // broadcast cluster 'unbind' message - if r.cluster != nil { - r.cluster.BroadcastMessage(&cluster.Message{ - Type: cluster.MsgUnbind, - Node: r.cluster.LocalNode(), - Payloads: []cluster.MessagePayload{{ - JID: stmJID, - }}, - }) - } -} - -// UserStreams returns the stream associated to a user jid. -func (r *Router) UserStream(j *jid.JID) stream.C2S { - r.mu.Lock() - defer r.mu.Unlock() - stms := r.streams[j.Node()] - - for _, stm := range stms { - if j.Matches(stm.JID(), jid.MatchesFull) { - return stm - } - } - return nil -} + // MustRoute forces stanza routing by ignoring user's blocking list. + MustRoute(ctx context.Context, stanza xmpp.Stanza) error -// UserStreams returns all streams associated to a user. -func (r *Router) UserStreams(username string) []stream.C2S { - r.mu.Lock() - defer r.mu.Unlock() - return r.streams[username] -} + // Bind sets a c2s stream as bound. + Bind(ctx context.Context, stm stream.C2S) -// IsBlockedJID returns whether or not the passed jid matches any of a user's blocking list jid. -func (r *Router) IsBlockedJID(jid *jid.JID, username string) bool { - bl := r.getBlockList(username) - for _, blkJID := range bl { - if r.jidMatchesBlockedJID(jid, blkJID) { - return true - } - } - return false -} - -// ReloadBlockList reloads in memory block list for a given user and starts applying it for future stanza routing. -func (r *Router) ReloadBlockList(username string) { - r.blockListsMu.Lock() - defer r.blockListsMu.Unlock() - - delete(r.blockLists, username) - log.Infof("block list reloaded... (username: %s)", username) -} - -// Route routes a stanza applying server rules for handling XML stanzas. -// (https://xmpp.org/rfcs/rfc3921.html#rules) -func (r *Router) Route(stanza xmpp.Stanza) error { - return r.route(stanza, false) -} + // Unbind unbinds a previously bound c2s stream. + Unbind(ctx context.Context, j *jid.JID) -// MustRoute routes a stanza applying server rules for handling XML stanzas -// ignoring blocking lists. -func (r *Router) MustRoute(stanza xmpp.Stanza) error { - return r.route(stanza, true) -} + // LocalStream returns the stream associated to a given username and resource. + LocalStream(username, resource string) stream.C2S -func (r *Router) jidMatchesBlockedJID(j, blockedJID *jid.JID) bool { - if blockedJID.IsFullWithUser() { - return j.Matches(blockedJID, jid.MatchesNode|jid.MatchesDomain|jid.MatchesResource) - } else if blockedJID.IsFullWithServer() { - return j.Matches(blockedJID, jid.MatchesDomain|jid.MatchesResource) - } else if blockedJID.IsBare() { - return j.Matches(blockedJID, jid.MatchesNode|jid.MatchesDomain) - } - return j.Matches(blockedJID, jid.MatchesDomain) + // LocalStreams returns all streams associated to a given username. + LocalStreams(username string) []stream.C2S } -func (r *Router) getBlockList(username string) []*jid.JID { - r.blockListsMu.RLock() - bl := r.blockLists[username] - r.blockListsMu.RUnlock() - if bl != nil { - return bl - } - blItems, err := storage.FetchBlockListItems(username) - if err != nil { - log.Error(err) - return nil - } - bl = []*jid.JID{} - for _, blItem := range blItems { - j, _ := jid.NewWithString(blItem.JID, true) - bl = append(bl, j) - } - r.blockListsMu.Lock() - r.blockLists[username] = bl - r.blockListsMu.Unlock() - return bl -} +type C2SRouter interface { + // Route routes a stanza applying server rules for handling XML stanzas. + // (https://xmpp.org/rfcs/rfc3921.html#rules) + Route(ctx context.Context, stanza xmpp.Stanza, validateStanza bool) error -func (r *Router) bind(stm stream.C2S) { - if usrStreams := r.streams[stm.Username()]; usrStreams != nil { - res := stm.Resource() - for _, usrStream := range usrStreams { - if usrStream.Resource() == res { - return // already bound - } - } - r.streams[stm.Username()] = append(usrStreams, stm) - } else { - r.streams[stm.Username()] = []stream.C2S{stm} - } -} + // Bind sets a c2s stream as bound. + Bind(stm stream.C2S) -func (r *Router) unbind(jid *jid.JID) bool { - found := false - if usrStreams := r.streams[jid.Node()]; usrStreams != nil { - res := jid.Resource() - for i := 0; i < len(usrStreams); i++ { - if res == usrStreams[i].Resource() { - usrStreams = append(usrStreams[:i], usrStreams[i+1:]...) - if len(usrStreams) > 0 { - r.streams[jid.Node()] = usrStreams - } else { - delete(r.streams, jid.Node()) - } - found = true - break - } - } - } - return found -} + // Unbind unbinds a previously bound c2s stream. + Unbind(username, resource string) -func (r *Router) route(element xmpp.Stanza, ignoreBlocking bool) error { - toJID := element.ToJID() - if !ignoreBlocking && !toJID.IsServer() { - if r.IsBlockedJID(element.FromJID(), toJID.Node()) { - return ErrBlockedJID - } - } - if !r.IsLocalHost(toJID.Domain()) { - return r.remoteRoute(element) - } - recipients := r.streams[toJID.Node()] - if len(recipients) == 0 { - exists, err := storage.UserExists(toJID.Node()) - if err != nil { - return err - } - if exists { - return ErrNotAuthenticated - } - return ErrNotExistingAccount - } - if toJID.IsFullWithUser() { - for _, stm := range recipients { - if stm.Resource() == toJID.Resource() { - stm.SendElement(element) - return nil - } - } - return ErrResourceNotFound - } - switch element.(type) { - case *xmpp.Message: - // send to highest priority stream - stm := recipients[0] - var highestPriority int8 - if p := stm.Presence(); p != nil { - highestPriority = p.Priority() - } - for i := 1; i < len(recipients); i++ { - rcp := recipients[i] - if p := rcp.Presence(); p != nil && p.Priority() > highestPriority { - stm = rcp - highestPriority = p.Priority() - } - } - stm.SendElement(element) + // Stream returns the stream associated to a given username and resource. + Stream(username, resource string) stream.C2S - default: - // broadcast toJID all streams - for _, stm := range recipients { - stm.SendElement(element) - } - } - return nil + // Streams returns all streams associated to a given username. + Streams(username string) []stream.C2S } -func (r *Router) remoteRoute(elem xmpp.Stanza) error { - if r.outS2SProvider == nil { - return ErrFailedRemoteConnect - } - localDomain := elem.FromJID().Domain() - remoteDomain := elem.ToJID().Domain() - - out, err := r.outS2SProvider.GetOut(localDomain, remoteDomain) - if err != nil { - log.Error(err) - return ErrFailedRemoteConnect - } - out.SendElement(elem) - return nil +type S2SRouter interface { + // Route routes a stanza applying server rules for handling XML stanzas. + // (https://xmpp.org/rfcs/rfc3921.html#rules) + Route(ctx context.Context, stanza xmpp.Stanza, localDomain string) error } -func (r *Router) handleNotifyMessage(msg *cluster.Message) { - switch msg.Type { - case cluster.MsgBatchBind, cluster.MsgBind: - r.processBindMessage(msg) - case cluster.MsgUnbind: - r.processUnbindMessage(msg) - case cluster.MsgUpdatePresence: - r.processUpdatePresenceMessage(msg) - case cluster.MsgUpdateContext: - r.processUpdateContext(msg) - case cluster.MsgRouteStanza: - r.processRouteStanzaMessage(msg) - } +type router struct { + hosts *host.Hosts + c2s C2SRouter + s2s S2SRouter } -func (r *Router) handleNodeJoined(node *cluster.Node) { - r.mu.RLock() - defer r.mu.RUnlock() - if r.cluster == nil { - return - } - - if node.Metadata.Version != version.ApplicationVersion.String() { - log.Warnf("incompatible server version: %s (node: %s)", node.Metadata.Version, node.Name) - return - } - if node.Metadata.GoVersion != runtime.Version() { - log.Warnf("incompatible runtime version: %s (node: %s)", node.Metadata.GoVersion, node.Name) - return - } - // send local JIDs in batches to the recently joined node - i := 0 - var payloads []cluster.MessagePayload - for _, stm := range r.localStreams { - payloads = append(payloads, cluster.MessagePayload{ - JID: stm.JID(), - Stanza: stm.Presence(), - Context: stm.Context(), - }) - i++ - if i == bindMsgBatchSize { - r.cluster.SendMessageTo(node.Name, &cluster.Message{ - Type: cluster.MsgBatchBind, - Node: r.cluster.LocalNode(), - Payloads: payloads, - }) - payloads = nil - i = 0 - } - } - // send remaining ones... - if len(payloads) > 0 { - r.cluster.SendMessageTo(node.Name, &cluster.Message{ - Type: cluster.MsgBatchBind, - Node: r.cluster.LocalNode(), - Payloads: payloads, - }) +func New(hosts *host.Hosts, c2sRouter C2SRouter, s2sRouter S2SRouter) (Router, error) { + r := &router{ + hosts: hosts, + c2s: c2sRouter, + s2s: s2sRouter, } + return r, nil } -func (r *Router) handleNodeLeft(node *cluster.Node) { - r.mu.Lock() - defer r.mu.Unlock() - - // unbind node streams - if streams := r.clusterStreams[node.Name]; streams != nil { - for _, stm := range streams { - r.unbind(stm.JID()) - } - } - delete(r.clusterStreams, node.Name) +func (r *router) Hosts() *host.Hosts { + return r.hosts } -func (r *Router) processBindMessage(msg *cluster.Message) { - r.mu.Lock() - defer r.mu.Unlock() - if r.cluster == nil { - return - } - - for _, p := range msg.Payloads { - j := p.JID - presence, ok := p.Stanza.(*xmpp.Presence) - if !ok { - continue - } - log.Debugf("bound cluster c2s: %s", j.String()) - - stm := r.cluster.C2SStream(j, presence, p.Context, msg.Node) - r.bind(stm) - r.registerClusterC2S(stm, msg.Node) - } +func (r *router) MustRoute(ctx context.Context, stanza xmpp.Stanza) error { + return r.route(ctx, stanza, false) } -func (r *Router) processUnbindMessage(msg *cluster.Message) { - r.mu.Lock() - defer r.mu.Unlock() - if r.cluster == nil { - return - } - j := msg.Payloads[0].JID - - log.Debugf("unbound cluster c2s: %s", j.String()) - r.unbind(j) - r.unregisterClusterC2S(j, msg.Node) +func (r *router) Route(ctx context.Context, stanza xmpp.Stanza) error { + return r.route(ctx, stanza, true) } -func (r *Router) processUpdateContext(msg *cluster.Message) { - r.mu.RLock() - defer r.mu.RUnlock() - if r.cluster == nil { - return - } - j := msg.Payloads[0].JID - context := msg.Payloads[0].Context - - log.Debugf("updated cluster c2s context: %s\n%v", j.String(), context) - - var stm *cluster.C2S - if streams := r.clusterStreams[msg.Node]; streams != nil { - stm = streams[j.String()] - } - if stm == nil { - return - } - stm.UpdateContext(context) +func (r *router) Bind(ctx context.Context, stm stream.C2S) { + r.c2s.Bind(stm) } -func (r *Router) processUpdatePresenceMessage(msg *cluster.Message) { - r.mu.RLock() - defer r.mu.RUnlock() - if r.cluster == nil { - return - } - j := msg.Payloads[0].JID - stanza := msg.Payloads[0].Stanza - - presence, ok := stanza.(*xmpp.Presence) - if !ok { - return - } - log.Debugf("updated cluster c2s presence: %s\n%v", j.String(), presence) - - var stm *cluster.C2S - if streams := r.clusterStreams[msg.Node]; streams != nil { - stm = streams[j.String()] - } - if stm == nil { - return - } - stm.SetPresence(presence) +func (r *router) Unbind(ctx context.Context, j *jid.JID) { + r.c2s.Unbind(j.Node(), j.Resource()) } -func (r *Router) processRouteStanzaMessage(msg *cluster.Message) { - r.mu.RLock() - defer r.mu.RUnlock() - if r.cluster == nil { - return - } - j := msg.Payloads[0].JID - stanza := msg.Payloads[0].Stanza - - log.Debugf("routing cluster stanza: %s\n%v", j.String(), stanza) - _ = r.route(stanza, false) +func (r *router) LocalStreams(username string) []stream.C2S { + return r.c2s.Streams(username) } -func (r *Router) registerClusterC2S(stm *cluster.C2S, node string) { - if streams := r.clusterStreams[node]; streams != nil { - streams[stm.JID().String()] = stm - } else { - r.clusterStreams[node] = map[string]*cluster.C2S{ - stm.JID().String(): stm, - } - } +func (r *router) LocalStream(username, resource string) stream.C2S { + return r.c2s.Stream(username, resource) } -func (r *Router) unregisterClusterC2S(jid *jid.JID, node string) { - if streams := r.clusterStreams[node]; streams != nil { - delete(streams, jid.String()) - if len(streams) == 0 { - delete(r.clusterStreams, node) +func (r *router) route(ctx context.Context, stanza xmpp.Stanza, validateStanza bool) error { + toJID := stanza.ToJID() + if !r.hosts.IsLocalHost(toJID.Domain()) && !r.hosts.IsConferenceHost(toJID.Domain()) { + if r.s2s == nil { + return ErrFailedRemoteConnect } + return r.s2s.Route(ctx, stanza, r.hosts.DefaultHostName()) } + return r.c2s.Route(ctx, stanza, validateStanza) } diff --git a/router/router_test.go b/router/router_test.go deleted file mode 100644 index 118117b98..000000000 --- a/router/router_test.go +++ /dev/null @@ -1,481 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package router - -import ( - "crypto/tls" - "os" - "runtime" - "testing" - "time" - - "github.com/ortuman/jackal/cluster" - "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" - "github.com/ortuman/jackal/stream" - "github.com/ortuman/jackal/version" - "github.com/ortuman/jackal/xmpp" - "github.com/ortuman/jackal/xmpp/jid" - "github.com/pborman/uuid" - "github.com/stretchr/testify/require" -) - -const routerOpTimeout = time.Millisecond * 250 - -type fakeClusterDelegate struct { - cluster cluster.Cluster - sendCh chan *cluster.Message - sendMessageToCalls int - broadcastMessageCalls int -} - -func (d *fakeClusterDelegate) LocalNode() string { - return "node1" -} - -func (d *fakeClusterDelegate) C2SStream(jid *jid.JID, presence *xmpp.Presence, context map[string]interface{}, node string) *cluster.C2S { - return d.cluster.C2SStream(jid, presence, context, node) -} - -func (d *fakeClusterDelegate) SendMessageTo(node string, message *cluster.Message) { - if d.sendCh != nil { - d.sendCh <- message - } - d.sendMessageToCalls++ -} - -func (d *fakeClusterDelegate) BroadcastMessage(msg *cluster.Message) { - d.broadcastMessageCalls++ -} - -type fakeS2SOut struct { - elems []xmpp.XElement -} - -func (f *fakeS2SOut) ID() string { return uuid.New() } -func (f *fakeS2SOut) SendElement(elem xmpp.XElement) { f.elems = append(f.elems, elem) } -func (f *fakeS2SOut) Disconnect(err error) {} - -type fakeOutS2SProvider struct{ s2sOut *fakeS2SOut } - -func (f *fakeOutS2SProvider) GetOut(localDomain, remoteDomain string) (stream.S2SOut, error) { - return f.s2sOut, nil -} - -func TestRouter_EmptyConfig(t *testing.T) { - defer os.RemoveAll("./.cert") - - r, _ := New(&Config{}) - require.True(t, r.IsLocalHost("localhost")) - require.Equal(t, 1, len(r.HostNames())) - require.Equal(t, 1, len(r.Certificates())) -} - -func TestRouter_SetCluster(t *testing.T) { - r, _, shutdown := setupTest() - defer shutdown() - - var del fakeClusterDelegate - r.SetCluster(&del) - require.Equal(t, &del, r.Cluster()) -} - -func TestRouter_ClusterDelegate(t *testing.T) { - r, _, shutdown := setupTest() - defer shutdown() - - del, ok := r.ClusterDelegate().(cluster.Delegate) - require.True(t, ok) - require.NotNil(t, del) -} - -func TestRouter_Binding(t *testing.T) { - r, _, shutdown := setupTest() - defer shutdown() - - var del fakeClusterDelegate - r.SetCluster(&del) - - j1, _ := jid.NewWithString("ortuman@jackal.im/balcony", false) - j2, _ := jid.NewWithString("ortuman@jackal.im/garden", false) - j3, _ := jid.NewWithString("hamlet@jackal.im/balcony", false) - j4, _ := jid.NewWithString("romeo@jackal.im/balcony", false) - j5, _ := jid.NewWithString("juliet@jackal.im/garden", false) - j6, _ := jid.NewWithString("juliet@jackal.im", false) // empty resource - j7, _ := jid.NewWithString("juliet@jackal.im/yard", false) - stm1 := stream.NewMockC2S(uuid.New(), j1) - stm2 := stream.NewMockC2S(uuid.New(), j2) - stm3 := stream.NewMockC2S(uuid.New(), j3) - stm4 := stream.NewMockC2S(uuid.New(), j4) - stm5 := stream.NewMockC2S(uuid.New(), j5) - stm6 := stream.NewMockC2S(uuid.New(), j6) - - r.Bind(stm1) - r.Bind(stm2) - r.Bind(stm3) - r.Bind(stm4) - r.Bind(stm5) - r.Bind(stm6) - - require.Equal(t, 5, del.broadcastMessageCalls) - - require.Equal(t, 2, len(r.UserStreams("ortuman"))) - require.Equal(t, 1, len(r.UserStreams("hamlet"))) - require.Equal(t, 1, len(r.UserStreams("romeo"))) - require.Equal(t, 1, len(r.UserStreams("juliet"))) - - r.Unbind(j7) - r.Unbind(j6) - r.Unbind(j5) - r.Unbind(j4) - r.Unbind(j3) - r.Unbind(j2) - r.Unbind(j1) - - require.Equal(t, 10, del.broadcastMessageCalls) - - require.Equal(t, 0, len(r.UserStreams("ortuman"))) - require.Equal(t, 0, len(r.UserStreams("hamlet"))) - require.Equal(t, 0, len(r.UserStreams("romeo"))) - require.Equal(t, 0, len(r.UserStreams("juliet"))) -} - -func TestRouter_Routing(t *testing.T) { - outS2S := fakeS2SOut{} - s2sOutProvider := fakeOutS2SProvider{s2sOut: &outS2S} - - r, s, shutdown := setupTest() - defer shutdown() - - r.SetOutS2SProvider(&s2sOutProvider) - - j1, _ := jid.NewWithString("ortuman@jackal.im/balcony", false) - j2, _ := jid.NewWithString("ortuman@jackal.im/garden", false) - j3, _ := jid.NewWithString("hamlet@jackal.im/balcony", false) - j4, _ := jid.NewWithString("hamlet@jackal.im/garden", false) - j5, _ := jid.NewWithString("hamlet@jackal.im", false) - j6, _ := jid.NewWithString("juliet@example.org/garden", false) - stm1 := stream.NewMockC2S(uuid.New(), j1) - stm2 := stream.NewMockC2S(uuid.New(), j2) - stm3 := stream.NewMockC2S(uuid.New(), j3) - - r.Bind(stm1) - r.Bind(stm2) - - iqID := uuid.New() - iq := xmpp.NewIQType(iqID, xmpp.SetType) - iq.SetFromJID(j1) - iq.SetToJID(j6) - - // remote routing - require.Nil(t, r.Route(iq)) - require.Equal(t, 1, len(outS2S.elems)) - - iq.SetToJID(j3) - require.Equal(t, ErrNotExistingAccount, r.Route(iq)) - - s.EnableMockedError() - require.Equal(t, memstorage.ErrMockedError, r.Route(iq)) - s.DisableMockedError() - - _ = storage.InsertOrUpdateUser(&model.User{Username: "hamlet", Password: ""}) - require.Equal(t, ErrNotAuthenticated, r.Route(iq)) - - stm4 := stream.NewMockC2S(uuid.New(), j4) - r.Bind(stm4) - require.Equal(t, ErrResourceNotFound, r.Route(iq)) - - r.Bind(stm3) - require.Nil(t, r.Route(iq)) - elem := stm3.ReceiveElement() - require.Equal(t, iqID, elem.ID()) - - // broadcast stanza - iq.SetToJID(j5) - require.Nil(t, r.Route(iq)) - elem = stm3.ReceiveElement() - require.Equal(t, iqID, elem.ID()) - elem = stm4.ReceiveElement() - require.Equal(t, iqID, elem.ID()) - - // send clusterMessage to highest priority - p1 := xmpp.NewElementName("presence") - p1.SetFrom(j3.String()) - p1.SetTo(j3.String()) - p1.SetType(xmpp.AvailableType) - pr1 := xmpp.NewElementName("priority") - pr1.SetText("2") - p1.AppendElement(pr1) - presence1, _ := xmpp.NewPresenceFromElement(p1, j3, j3) - stm3.SetPresence(presence1) - - p2 := xmpp.NewElementName("presence") - p2.SetFrom(j4.String()) - p2.SetTo(j4.String()) - p2.SetType(xmpp.AvailableType) - pr2 := xmpp.NewElementName("priority") - pr2.SetText("1") - p2.AppendElement(pr2) - presence2, _ := xmpp.NewPresenceFromElement(p2, j4, j4) - stm4.SetPresence(presence2) - - msgID := uuid.New() - msg := xmpp.NewMessageType(msgID, xmpp.ChatType) - msg.SetToJID(j5) - require.Nil(t, r.Route(msg)) - elem = stm3.ReceiveElement() - require.Equal(t, msgID, elem.ID()) -} - -func TestRouter_BlockedJID(t *testing.T) { - r, _, shutdown := setupTest() - defer shutdown() - - j1, _ := jid.NewWithString("ortuman@jackal.im/balcony", false) - j2, _ := jid.NewWithString("hamlet@jackal.im/balcony", false) - j3, _ := jid.NewWithString("hamlet@jackal.im/garden", false) - j4, _ := jid.NewWithString("juliet@jackal.im/garden", false) - stm1 := stream.NewMockC2S(uuid.New(), j1) - stm2 := stream.NewMockC2S(uuid.New(), j2) - - r.Bind(stm1) - r.Bind(stm2) - - // node + domain + resource - bl1 := []model.BlockListItem{{ - Username: "ortuman", - JID: "hamlet@jackal.im/garden", - }} - _ = storage.InsertBlockListItems(bl1) - require.False(t, r.IsBlockedJID(j2, "ortuman")) - require.True(t, r.IsBlockedJID(j3, "ortuman")) - - _ = storage.DeleteBlockListItems(bl1) - - // node + domain - bl2 := []model.BlockListItem{{ - Username: "ortuman", - JID: "hamlet@jackal.im", - }} - _ = storage.InsertBlockListItems(bl2) - r.ReloadBlockList("ortuman") - - require.True(t, r.IsBlockedJID(j2, "ortuman")) - require.True(t, r.IsBlockedJID(j3, "ortuman")) - require.False(t, r.IsBlockedJID(j4, "ortuman")) - - _ = storage.DeleteBlockListItems(bl2) - - // domain + resource - bl3 := []model.BlockListItem{{ - Username: "ortuman", - JID: "jackal.im/balcony", - }} - _ = storage.InsertBlockListItems(bl3) - r.ReloadBlockList("ortuman") - - require.True(t, r.IsBlockedJID(j2, "ortuman")) - require.False(t, r.IsBlockedJID(j3, "ortuman")) - require.False(t, r.IsBlockedJID(j4, "ortuman")) - - _ = storage.DeleteBlockListItems(bl3) - - // domain - bl4 := []model.BlockListItem{{ - Username: "ortuman", - JID: "jackal.im", - }} - _ = storage.InsertBlockListItems(bl4) - r.ReloadBlockList("ortuman") - - require.True(t, r.IsBlockedJID(j2, "ortuman")) - require.True(t, r.IsBlockedJID(j3, "ortuman")) - require.True(t, r.IsBlockedJID(j4, "ortuman")) - - _ = storage.DeleteBlockListItems(bl4) - - // test blocked routing - iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) - iq.SetFromJID(j2) - iq.SetToJID(j1) - require.Equal(t, ErrBlockedJID, r.Route(iq)) -} - -func TestRouter_Cluster(t *testing.T) { - r, _, shutdown := setupTest() - defer shutdown() - - var del fakeClusterDelegate - del.sendCh = make(chan *cluster.Message, 2) - r.SetCluster(&del) - - j1, _ := jid.NewWithString("ortuman@jackal.im/balcony", false) - j2, _ := jid.NewWithString("ortuman@jackal.im/garden", false) - j3, _ := jid.NewWithString("hamlet@jackal.im/balcony", false) - stm1 := stream.NewMockC2S(uuid.New(), j1) - stm2 := stream.NewMockC2S(uuid.New(), j2) - stm3 := stream.NewMockC2S(uuid.New(), j3) - - r.Bind(stm1) - r.Bind(stm2) - r.Bind(stm3) - - node := &cluster.Node{ - Name: "node2", - Metadata: cluster.Metadata{ - Version: version.ApplicationVersion.String(), - GoVersion: runtime.Version(), - }, - } - bindMsgBatchSize = 2 - - r.handleNodeJoined(node) - - // expecting 2 batches - for i := 0; i < 2; i++ { - select { - case <-del.sendCh: - break - case <-time.After(routerOpTimeout): - require.Fail(t, "handle cluster join timeout") - } - } - require.Equal(t, 2, del.sendMessageToCalls) - - // try to join with incompatible version - r.handleNodeJoined(&cluster.Node{ - Name: "node3", - Metadata: cluster.Metadata{ - Version: version.ApplicationVersion.String(), - GoVersion: "v0.1", - }, - }) - r.handleNodeJoined(&cluster.Node{ - Name: "node4", - Metadata: cluster.Metadata{ - Version: "v0.0.0.1.rc2", - GoVersion: runtime.Version(), - }, - }) - require.Equal(t, 2, del.sendMessageToCalls) // nothing happened - - r.SetCluster(nil) - r.handleNodeJoined(node) - require.Equal(t, 2, del.sendMessageToCalls) // nothing happened - - // process bind message - r.SetCluster(&del) - - j4, _ := jid.NewWithString("noelia@jackal.im/balcony", true) - j5, _ := jid.NewWithString("noelia@jackal.im/yard", true) - - r.handleNotifyMessage(&cluster.Message{ - Type: cluster.MsgBind, - Node: "node2", - Payloads: []cluster.MessagePayload{{ - JID: j4, - Stanza: xmpp.NewPresence(j4, j4, xmpp.AvailableType), - Context: map[string]interface{}{}, - }}, - }) - r.handleNotifyMessage(&cluster.Message{ - Type: cluster.MsgBind, - Node: "node2", - Payloads: []cluster.MessagePayload{{ - JID: j5, - Stanza: xmpp.NewPresence(j5, j5, xmpp.AvailableType), - Context: map[string]interface{}{}, - }}, - }) - r.mu.RLock() - require.Equal(t, 2, len(r.clusterStreams["node2"])) - r.mu.RUnlock() - - r.handleNotifyMessage(&cluster.Message{ - Type: cluster.MsgUnbind, - Node: "node2", - Payloads: []cluster.MessagePayload{{ - JID: j5, - Stanza: xmpp.NewPresence(j5, j5, xmpp.AvailableType), - }}, - }) - r.mu.RLock() - require.Equal(t, 1, len(r.clusterStreams["node2"])) - r.mu.RUnlock() - - // update cluster stream presence - p := xmpp.NewPresence(j4, j4, xmpp.UnavailableType) - r.handleNotifyMessage(&cluster.Message{ - Type: cluster.MsgUpdatePresence, - Node: "node2", - Payloads: []cluster.MessagePayload{{ - JID: j4, - Stanza: p, - }}, - }) - r.mu.RLock() - stm := r.clusterStreams["node2"][j4.String()] - require.NotNil(t, stm) - require.Equal(t, stm.Presence(), p) - r.mu.RUnlock() - - // update cluster stream context - r.handleNotifyMessage(&cluster.Message{ - Type: cluster.MsgUpdateContext, - Node: "node2", - Payloads: []cluster.MessagePayload{{ - JID: j4, - Context: map[string]interface{}{ - "var": "foo", - }, - }}, - }) - r.mu.RLock() - stm = r.clusterStreams["node2"][j4.String()] - require.NotNil(t, stm) - require.Equal(t, "foo", stm.GetString("var")) - r.mu.RUnlock() - - r.handleNodeLeft(&cluster.Node{ - Name: "node2", - Metadata: cluster.Metadata{ - Version: version.ApplicationVersion.String(), - GoVersion: runtime.Version(), - }, - }) - r.mu.RLock() - require.Equal(t, 0, len(r.clusterStreams["node2"])) - r.mu.RUnlock() - - // test cluster stanza routing - iq := xmpp.NewIQType(uuid.New(), xmpp.GetType) - iq.SetFromJID(j4) - iq.SetToJID(j3) - - r.handleNotifyMessage(&cluster.Message{ - Type: cluster.MsgRouteStanza, - Node: "node2", - Payloads: []cluster.MessagePayload{{ - JID: j4, - Stanza: iq, - }}, - }) - elem := stm3.ReceiveElement() - require.NotNil(t, elem) - require.Equal(t, elem, iq) -} - -func setupTest() (*Router, *memstorage.Storage, func()) { - r, _ := New(&Config{ - Hosts: []HostConfig{{Name: "jackal.im", Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } -} diff --git a/s2s/config.go b/s2s/config.go index 0f8bfcb0f..bfd78a22a 100644 --- a/s2s/config.go +++ b/s2s/config.go @@ -11,10 +11,7 @@ import ( "time" "github.com/netsec-ethz/scion-apps/lib/scionutil" - "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/stream" - "github.com/ortuman/jackal/transport" - "github.com/ortuman/jackal/xmpp" "github.com/pkg/errors" "github.com/scionproto/scion/go/lib/sciond" "github.com/scionproto/scion/go/lib/sock/reliable" @@ -26,6 +23,7 @@ const ( defaultTransportKeepAlive = time.Duration(10) * time.Minute defaultDialTimeout = time.Duration(15) * time.Second defaultConnectTimeout = time.Duration(5) * time.Second + defaultTimeout = time.Duration(20) * time.Second defaultMaxStanzaSize = 131072 ) @@ -33,13 +31,11 @@ const ( type TransportConfig struct { BindAddress string Port int - KeepAlive time.Duration } type transportConfigProxy struct { BindAddress string `yaml:"bind_addr"` Port int `yaml:"port"` - KeepAlive int `yaml:"keep_alive"` } // UnmarshalYAML satisfies Unmarshaler interface. @@ -53,11 +49,6 @@ func (c *TransportConfig) UnmarshalYAML(unmarshal func(interface{}) error) error if c.Port == 0 { c.Port = defaultTransportPort } - if p.KeepAlive > 0 { - c.KeepAlive = time.Duration(p.KeepAlive) * time.Second - } else { - c.KeepAlive = defaultTransportKeepAlive - } return nil } @@ -137,6 +128,8 @@ type Config struct { ID string DialTimeout time.Duration ConnectTimeout time.Duration + KeepAlive time.Duration + Timeout time.Duration DialbackSecret string MaxStanzaSize int Transport TransportConfig @@ -147,6 +140,8 @@ type configProxy struct { ID string `yaml:"id"` DialTimeout int `yaml:"dial_timeout"` ConnectTimeout int `yaml:"connect_timeout"` + KeepAlive int `yaml:"keep_alive"` + Timeout int `yaml:"timeout"` DialbackSecret string `yaml:"dialback_secret"` MaxStanzaSize int `yaml:"max_stanza_size"` Transport TransportConfig `yaml:"transport"` @@ -172,6 +167,15 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { if c.ConnectTimeout == 0 { c.ConnectTimeout = defaultConnectTimeout } + if p.KeepAlive > 0 { + c.KeepAlive = time.Duration(p.KeepAlive) * time.Second + } else { + c.KeepAlive = defaultTransportKeepAlive + } + c.Timeout = time.Duration(p.Timeout) * time.Second + if c.Timeout == 0 { + c.Timeout = defaultTimeout + } c.Transport = p.Transport c.MaxStanzaSize = p.MaxStanzaSize if c.MaxStanzaSize == 0 { @@ -181,17 +185,23 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } -type streamConfig struct { - modConfig *module.Config - keyGen *keyGen - localDomain string - remoteDomain string - connectTimeout time.Duration - tls *tls.Config - transport transport.Transport - maxStanzaSize int - dbVerify xmpp.XElement - dialer *dialer - onInDisconnect func(s stream.S2SIn) - onOutDisconnect func(s stream.S2SOut) +type inConfig struct { + keyGen *keyGen + connectTimeout time.Duration + timeout time.Duration + keepAlive time.Duration + tls *tls.Config + maxStanzaSize int + onDisconnect func(s stream.S2SIn) +} + +type outConfig struct { + keyGen *keyGen + localDomain string + remoteDomain string + timeout time.Duration + keepAlive time.Duration + tls *tls.Config + maxStanzaSize int + scion *ScionConfig } diff --git a/s2s/config_test.go b/s2s/config_test.go index 132dd9dfe..01b6c5c2e 100644 --- a/s2s/config_test.go +++ b/s2s/config_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/stretchr/testify/require" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) func TestTransportConfig(t *testing.T) { @@ -28,18 +28,15 @@ bind_addr: 0.0.0.0 require.Nil(t, err) require.Equal(t, "0.0.0.0", trCfg.BindAddress) require.Equal(t, 5269, trCfg.Port) - require.Equal(t, time.Duration(600)*time.Second, trCfg.KeepAlive) rawCfg = ` bind_addr: 127.0.0.1 port: 5999 -keep_alive: 200 ` err = yaml.Unmarshal([]byte(rawCfg), &trCfg) require.Nil(t, err) require.Equal(t, "127.0.0.1", trCfg.BindAddress) require.Equal(t, 5999, trCfg.Port) - require.Equal(t, time.Duration(200)*time.Second, trCfg.KeepAlive) } func TestConfig(t *testing.T) { diff --git a/s2s/dial.go b/s2s/dial.go deleted file mode 100644 index 6229f0530..000000000 --- a/s2s/dial.go +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package s2s - -import ( - "crypto/tls" - "net" - "strconv" - "strings" - "time" - - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/transport" - - "github.com/lucas-clemente/quic-go" - "github.com/netsec-ethz/scion-apps/lib/scionutil" - "github.com/scionproto/scion/go/lib/sciond" - "github.com/scionproto/scion/go/lib/snet" - "github.com/scionproto/scion/go/lib/snet/squic" - "github.com/scionproto/scion/go/lib/sock/reliable" - - libaddr "github.com/scionproto/scion/go/lib/addr" -) - -type dialer struct { - cfg *Config - router *router.Router - srvResolve func(service, proto, name string) (cname string, addrs []*net.SRV, err error) - dialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) -} - -func newDialer(cfg *Config, router *router.Router) *dialer { - return &dialer{cfg: cfg, router: router, srvResolve: net.LookupSRV, dialTimeout: net.DialTimeout} -} - -func (d *dialer) dialQUIC(remote *snet.Addr, localDomain, remoteDomain string) (*streamConfig, error) { - var local *snet.Addr - var err error - if d.cfg.Scion.Address == "localhost" { - local, err = scionutil.GetLocalhost() - } else { - local, err = snet.AddrFromString(d.cfg.Scion.Address) - } - if err != nil { - return nil, err - } - - sciondPath := sciond.GetDefaultSCIONDPath(nil) - dispatcherPath := d.cfg.Scion.Dispatcher - snet.Init(local.IA, sciondPath, reliable.NewDispatcherService(dispatcherPath)) - quicConfig := &quic.Config{ - KeepAlive: true, - } - sess, err := squic.DialSCION(nil, local, remote, quicConfig) - if err != nil { - return nil, err - } - biStream, err := sess.OpenStreamSync() - if err != nil { - log.Infof("Couldn't open a new QUIC Stream") - } - - tr := transport.NewQUICSocketTransport(sess, biStream, - d.cfg.Transport.KeepAlive) - return &streamConfig{ - keyGen: &keyGen{secret: d.cfg.DialbackSecret}, - localDomain: localDomain, - remoteDomain: remoteDomain, - transport: tr, - maxStanzaSize: d.cfg.MaxStanzaSize, - }, nil -} - -func (d *dialer) dialTCP(localDomain, remoteDomain string) (*streamConfig, error) { - _, addrs, err := d.srvResolve("xmpp-server", "tcp", remoteDomain) - if err != nil { - log.Warnf("srv lookup error: %v", err) - } - var target string - - if err != nil || len(addrs) == 1 && addrs[0].Target == "." { - target = remoteDomain + ":5269" - } else { - target = strings.TrimSuffix(addrs[0].Target, ".") + ":" + strconv.Itoa(int(addrs[0].Port)) - } - conn, err := d.dialTimeout("tcp", target, d.cfg.DialTimeout) - if err != nil { - return nil, err - } - tlsConfig := &tls.Config{ - ServerName: remoteDomain, - Certificates: d.router.Certificates(), - } - tr := transport.NewSocketTransport(conn, d.cfg.Transport.KeepAlive) - return &streamConfig{ - keyGen: &keyGen{secret: d.cfg.DialbackSecret}, - localDomain: localDomain, - remoteDomain: remoteDomain, - transport: tr, - tls: tlsConfig, - maxStanzaSize: d.cfg.MaxStanzaSize, - }, nil -} - -func (d *dialer) dial(localDomain, remoteDomain string) (*streamConfig, error) { - var ret *streamConfig - var err error - isSCIONAddress, remote := rainsLookup(remoteDomain) - if isSCIONAddress { - ret, err = d.dialQUIC(remote, localDomain, remoteDomain) - } else { - ret, err = d.dialTCP(localDomain, remoteDomain) - } - if err != nil { - return nil, err - } - return ret, nil -} - -func rainsLookup(remoteDomain string) (bool, *snet.Addr) { - host, port, err := net.SplitHostPort(remoteDomain) - if err != nil { - host = remoteDomain - port = "52690" - } - ia, l3, err := scionutil.GetHostByName(host + ".") - if err != nil { - return false, nil - } - - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - p = 52690 - } - l4 := libaddr.NewL4UDPInfo(uint16(p)) - raddr := &snet.Addr{IA: ia, Host: &libaddr.AppAddr{L3: l3, L4: l4}} - - return true, raddr -} diff --git a/s2s/dialer.go b/s2s/dialer.go new file mode 100644 index 000000000..5de59c81f --- /dev/null +++ b/s2s/dialer.go @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package s2s + +import ( + "context" + "net" + "strconv" + "strings" + + "github.com/ortuman/jackal/log" + + "github.com/lucas-clemente/quic-go" + "github.com/netsec-ethz/scion-apps/lib/scionutil" + "github.com/scionproto/scion/go/lib/sciond" + "github.com/scionproto/scion/go/lib/snet" + "github.com/scionproto/scion/go/lib/snet/squic" + "github.com/scionproto/scion/go/lib/sock/reliable" + + libaddr "github.com/scionproto/scion/go/lib/addr" +) + +type Dialer interface { + DialTCP(ctx context.Context, remoteDomain string) (net.Conn, error) + DialQUIC(cfg *ScionConfig, remote *snet.Addr, localDomain, remoteDomain string) (quic.Session, error) +} + +type srvResolveFunc func(service, proto, name string) (cname string, addrs []*net.SRV, err error) +type dialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +type dialer struct { + srvResolve srvResolveFunc + dialContext dialFunc +} + +func newDialer() *dialer { + var d net.Dialer + return &dialer{ + srvResolve: net.LookupSRV, + dialContext: d.DialContext, + } +} + +func (d *dialer) DialTCP(ctx context.Context, remoteDomain string) (net.Conn, error) { + _, address, err := d.srvResolve("xmpp-server", "tcp", remoteDomain) + if err != nil { + log.Warnf("srv lookup error: %v", err) + } + var target string + + if err != nil || len(address) == 1 && address[0].Target == "." { + target = remoteDomain + ":5269" + } else { + target = strings.TrimSuffix(address[0].Target, ".") + ":" + strconv.Itoa(int(address[0].Port)) + } + conn, err := d.dialContext(ctx, "tcp", target) + if err != nil { + return nil, err + } + return conn, err +} + +func (d *dialer) DialQUIC(cfg *ScionConfig, remote *snet.Addr, localDomain, remoteDomain string) (quic.Session, error) { + var local *snet.Addr + var err error + if cfg.Address == "localhost" { + local, err = scionutil.GetLocalhost() + } else { + local, err = snet.AddrFromString(cfg.Address) + } + if err != nil { + return nil, err + } + + sciondPath := sciond.GetDefaultSCIONDPath(nil) + dispatcherPath := cfg.Dispatcher + snet.Init(local.IA, sciondPath, reliable.NewDispatcherService(dispatcherPath)) + quicConfig := &quic.Config{ + KeepAlive: true, + } + sess, err := squic.DialSCION(nil, local, remote, quicConfig) + if err != nil { + return nil, err + } + return sess, err +} + +func rainsLookup(remoteDomain string) (bool, *snet.Addr) { + host, port, err := net.SplitHostPort(remoteDomain) + if err != nil { + host = remoteDomain + port = "52690" + } + ia, l3, err := scionutil.GetHostByName(host + ".") + if err != nil { + return false, nil + } + + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + p = 52690 + } + l4 := libaddr.NewL4UDPInfo(uint16(p)) + raddr := &snet.Addr{IA: ia, Host: &libaddr.AppAddr{L3: l3, L4: l4}} + + return true, raddr +} diff --git a/s2s/dial_test.go b/s2s/dialer_test.go similarity index 55% rename from s2s/dial_test.go rename to s2s/dialer_test.go index 9164ac220..77d58ccfa 100644 --- a/s2s/dial_test.go +++ b/s2s/dialer_test.go @@ -6,36 +6,23 @@ package s2s import ( + "context" "errors" "net" "testing" - "time" "github.com/stretchr/testify/require" ) -func TestS2SDial(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() - - cfg := &Config{ - ConnectTimeout: time.Second * time.Duration(5), - MaxStanzaSize: 8192, - Transport: TransportConfig{ - Port: 9778, - KeepAlive: time.Duration(600) * time.Second, - }, - } - - // not enabled - d := newDialer(cfg, r) +func TestDialer_Dial(t *testing.T) { + d := newDialer() // resolver error... mockedErr := errors.New("dialer mocked error") d.srvResolve = func(_, _, _ string) (cname string, addrs []*net.SRV, err error) { return "", nil, mockedErr } - out, err := d.dial("jackal.im", "jabber.org") + out, err := d.DialTCP(context.Background(), "jabber.org") require.NotNil(t, out) require.Nil(t, err) @@ -43,18 +30,18 @@ func TestS2SDial(t *testing.T) { d.srvResolve = func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "", []*net.SRV{{Target: "xmpp.jabber.org", Port: 5269}}, nil } - d.dialTimeout = func(_, _ string, _ time.Duration) (net.Conn, error) { + d.dialContext = func(_ context.Context, _, _ string) (net.Conn, error) { return nil, mockedErr } - out, err = d.dial("jackal.im", "jabber.org") + out, err = d.DialTCP(context.Background(), "jabber.org") require.Nil(t, out) require.Equal(t, mockedErr, err) // success - d.dialTimeout = func(_, _ string, _ time.Duration) (net.Conn, error) { + d.dialContext = func(_ context.Context, _, _ string) (net.Conn, error) { return newFakeSocketConn(), nil } - out, err = d.dial("jackal.im", "jabber.org") + out, err = d.DialTCP(context.Background(), "jabber.org") require.NotNil(t, out) require.Nil(t, err) } diff --git a/s2s/in.go b/s2s/in.go index b06c56d42..83b80cd34 100644 --- a/s2s/in.go +++ b/s2s/in.go @@ -6,8 +6,10 @@ package s2s import ( + "context" "crypto/tls" "fmt" + "sync" "sync/atomic" "time" @@ -15,8 +17,9 @@ import ( "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/runqueue" "github.com/ortuman/jackal/session" + "github.com/ortuman/jackal/transport" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) @@ -29,27 +32,31 @@ const ( type inStream struct { id string - cfg *streamConfig - router *router.Router + cfg *inConfig + router router.Router mods *module.Modules localDomain string remoteDomain string state uint32 + tr transport.Transport + mu sync.RWMutex connectTm *time.Timer + readTimeoutTm *time.Timer sess *session.Session secured uint32 authenticated uint32 + newOut newOutFunc runQueue *runqueue.RunQueue } -func newInStream(config *streamConfig, mods *module.Modules, router *router.Router, - alreadySecuredAndAuthd bool) *inStream { - +func newInStream(config *inConfig, tr transport.Transport, mods *module.Modules, newOutFn newOutFunc, router router.Router, alreadySecuredAndAuthd bool) *inStream { id := nextInID() s := &inStream{ id: id, cfg: config, + tr: tr, router: router, + newOut: newOutFn, mods: mods, runQueue: runqueue.New(id), } @@ -71,55 +78,63 @@ func (s *inStream) ID() string { return s.id } -func (s *inStream) Disconnect(err error) { +func (s *inStream) Disconnect(ctx context.Context, err error) { if s.getState() == inDisconnected { return } waitCh := make(chan struct{}) s.runQueue.Run(func() { - s.disconnect(err) + s.disconnect(ctx, err) close(waitCh) }) <-waitCh } func (s *inStream) connectTimeout() { - s.runQueue.Run(func() { s.disconnect(streamerror.ErrConnectionTimeout) }) + s.runQueue.Run(func() { + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + s.disconnect(ctx, streamerror.ErrConnectionTimeout) + }) } // runs on its own goroutine func (s *inStream) doRead() { - if elem, sErr := s.sess.Receive(); sErr == nil { + s.scheduleReadTimeout() + elem, sErr := s.sess.Receive() + s.cancelReadTimeout() + + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + if sErr == nil { s.runQueue.Run(func() { - s.readElement(elem) + s.readElement(ctx, elem) }) } else { s.runQueue.Run(func() { if s.getState() == inDisconnected { return // already disconnected... } - s.handleSessionError(sErr) + s.handleSessionError(ctx, sErr) }) } } -func (s *inStream) handleElement(elem xmpp.XElement) { +func (s *inStream) handleElement(ctx context.Context, elem xmpp.XElement) { switch s.getState() { case inConnecting: - s.handleConnecting(elem) + s.handleConnecting(ctx, elem) case inConnected: - s.handleConnected(elem) + s.handleConnected(ctx, elem) } } -func (s *inStream) handleConnecting(elem xmpp.XElement) { +func (s *inStream) handleConnecting(ctx context.Context, elem xmpp.XElement) { // cancel connection timeout timer if s.connectTm != nil { s.connectTm.Stop() s.connectTm = nil } // assign domain pair - s.localDomain = elem.To() + s.localDomain = s.router.Hosts().DefaultHostName() s.remoteDomain = elem.From() // open stream session @@ -133,15 +148,15 @@ func (s *inStream) handleConnecting(elem xmpp.XElement) { features.SetAttribute("version", "1.0") if !s.isSecured() { - starttls := xmpp.NewElementNamespace("starttls", tlsNamespace) - starttls.AppendElement(xmpp.NewElementName("required")) - features.AppendElement(starttls) + startTLS := xmpp.NewElementNamespace("starttls", tlsNamespace) + startTLS.AppendElement(xmpp.NewElementName("required")) + features.AppendElement(startTLS) s.setState(inConnected) - _ = s.sess.Open(features) + _ = s.sess.Open(ctx, features) return } - _ = s.sess.Open(nil) + _ = s.sess.Open(ctx, nil) if !s.isAuthenticated() { // offer external authentication @@ -157,81 +172,82 @@ func (s *inStream) handleConnecting(elem xmpp.XElement) { features.AppendElement(dbBack) s.setState(inConnected) - s.writeElement(features) + s.writeElement(ctx, features) } -func (s *inStream) handleConnected(elem xmpp.XElement) { +func (s *inStream) handleConnected(ctx context.Context, elem xmpp.XElement) { if !s.isSecured() { - s.proceedStartTLS(elem) + s.proceedStartTLS(ctx, elem) return } if !s.isAuthenticated() && elem.Name() == "auth" { - s.startAuthentication(elem) + s.startAuthentication(ctx, elem) return } switch elem.Name() { case "db:result": - s.authorizeDialbackKey(elem) + s.authorizeDialbackKey(ctx, elem) case "db:verify": - s.verifyDialbackKey(elem) + s.verifyDialbackKey(ctx, elem) default: switch elem := elem.(type) { case xmpp.Stanza: - s.processStanza(elem) + s.processStanza(ctx, elem) } } } -func (s *inStream) processStanza(stanza xmpp.Stanza) { +func (s *inStream) processStanza(ctx context.Context, stanza xmpp.Stanza) { switch stanza := stanza.(type) { case *xmpp.Presence: - s.processPresence(stanza) + s.processPresence(ctx, stanza) case *xmpp.IQ: - s.processIQ(stanza) + s.processIQ(ctx, stanza) case *xmpp.Message: - s.processMessage(stanza) + s.processMessage(ctx, stanza) } } -func (s *inStream) processPresence(presence *xmpp.Presence) { +func (s *inStream) processPresence(ctx context.Context, presence *xmpp.Presence) { // process roster presence if presence.ToJID().IsBare() { if r := s.mods.Roster; r != nil { - s.mods.Roster.ProcessPresence(presence) + r.ProcessPresence(ctx, presence) + return } - return } - _ = s.router.Route(presence) + _ = s.router.Route(ctx, presence) } -func (s *inStream) processIQ(iq *xmpp.IQ) { +func (s *inStream) processIQ(ctx context.Context, iq *xmpp.IQ) { toJID := iq.ToJID() - replyOnBehalf := !toJID.IsFullWithUser() && s.router.IsLocalHost(toJID.Domain()) + replyOnBehalf := !toJID.IsFullWithUser() && (s.router.Hosts().IsLocalHost(toJID.Domain()) || + s.router.Hosts().IsConferenceHost(toJID.Domain())) if !replyOnBehalf { - switch s.router.Route(iq) { + switch s.router.Route(ctx, iq) { case router.ErrResourceNotFound: - s.writeElement(iq.ServiceUnavailableError()) + s.writeElement(ctx, iq.ServiceUnavailableError()) case router.ErrFailedRemoteConnect: - s.writeElement(iq.RemoteServerNotFoundError()) + s.writeElement(ctx, iq.RemoteServerNotFoundError()) case router.ErrBlockedJID: // Destination user is a blocked JID if iq.IsGet() || iq.IsSet() { - s.writeElement(iq.ServiceUnavailableError()) + s.writeElement(ctx, iq.ServiceUnavailableError()) } } return } - s.mods.ProcessIQ(iq) + s.mods.ProcessIQ(ctx, iq) } -func (s *inStream) processMessage(message *xmpp.Message) { +func (s *inStream) processMessage(ctx context.Context, message *xmpp.Message) { msg := message sendMessage: - err := s.router.Route(msg) + err := s.router.Route(ctx, msg) switch err { case nil: break @@ -241,7 +257,7 @@ sendMessage: goto sendMessage case router.ErrNotAuthenticated: if off := s.mods.Offline; off != nil { - off.ArchiveMessage(message) + off.ArchiveMessage(ctx, message) return } default: @@ -250,21 +266,21 @@ sendMessage: } } -func (s *inStream) proceedStartTLS(elem xmpp.XElement) { +func (s *inStream) proceedStartTLS(ctx context.Context, elem xmpp.XElement) { if elem.Namespace() != tlsNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } else if elem.Name() != "starttls" { - s.disconnectWithStreamError(streamerror.ErrNotAuthorized) + s.disconnectWithStreamError(ctx, streamerror.ErrNotAuthorized) return } - s.writeElement(xmpp.NewElementNamespace("proceed", tlsNamespace)) + s.writeElement(ctx, xmpp.NewElementNamespace("proceed", tlsNamespace)) - s.cfg.transport.StartTLS(&tls.Config{ + s.tr.StartTLS(&tls.Config{ ServerName: s.localDomain, ClientAuth: tls.VerifyClientCertIfGiven, - Certificates: s.router.Certificates(), + Certificates: s.router.Hosts().Certificates(), }, false) atomic.StoreUint32(&s.secured, 1) @@ -272,38 +288,38 @@ func (s *inStream) proceedStartTLS(elem xmpp.XElement) { s.restartSession() } -func (s *inStream) startAuthentication(elem xmpp.XElement) { +func (s *inStream) startAuthentication(ctx context.Context, elem xmpp.XElement) { if elem.Namespace() != saslNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } if elem.Attributes().Get("mechanism") != "EXTERNAL" { - s.failAuthentication("invalid-mechanism", "") + s.failAuthentication(ctx, "invalid-mechanism", "") return } // validate initiating server certificate - certs := s.cfg.transport.PeerCertificates() + certs := s.tr.PeerCertificates() for _, cert := range certs { for _, dnsName := range cert.DNSNames { if dnsName == s.remoteDomain { - s.finishAuthentication() + s.finishAuthentication(ctx) return } } } - s.failAuthentication("bad-protocol", "failed to get peer certificate") + s.failAuthentication(ctx, "bad-protocol", "failed to get peer certificate") } -func (s *inStream) finishAuthentication() { +func (s *inStream) finishAuthentication(ctx context.Context) { log.Infof("s2s in stream authenticated") atomic.StoreUint32(&s.authenticated, 1) success := xmpp.NewElementNamespace("success", saslNamespace) - s.writeElement(success) + s.writeElement(ctx, success) s.restartSession() } -func (s *inStream) failAuthentication(reason, text string) { +func (s *inStream) failAuthentication(ctx context.Context, reason, text string) { log.Infof("failed s2s in stream authentication: %s (text: %s)", reason, text) failure := xmpp.NewElementNamespace("failure", saslNamespace) failure.AppendElement(xmpp.NewElementName(reason)) @@ -312,37 +328,24 @@ func (s *inStream) failAuthentication(reason, text string) { textEl.SetText(text) failure.AppendElement(textEl) } - s.writeElement(failure) + s.writeElement(ctx, failure) } -func (s *inStream) authorizeDialbackKey(elem xmpp.XElement) { - if !s.router.IsLocalHost(elem.To()) { - s.writeStanzaErrorResponse(elem, xmpp.ErrItemNotFound) +func (s *inStream) authorizeDialbackKey(ctx context.Context, elem xmpp.XElement) { + if !s.router.Hosts().IsLocalHost(elem.To()) { + s.writeStanzaErrorResponse(ctx, elem, xmpp.ErrItemNotFound) return } log.Infof("authorizing dialback key: %s...", elem.Text()) - outCfg, err := s.cfg.dialer.dial(elem.To(), elem.From()) - if err != nil { - log.Error(err) - s.writeStanzaErrorResponse(elem, xmpp.ErrRemoteServerNotFound) - return - } - // create verify element - dbVerify := xmpp.NewElementName("db:verify") - dbVerify.SetID(s.sess.StreamID()) - dbVerify.SetFrom(elem.To()) - dbVerify.SetTo(elem.From()) - dbVerify.SetText(elem.Text()) - outCfg.dbVerify = dbVerify + // verify stream + outStm := s.newOut(s.router.Hosts().DefaultHostName(), elem.From()) - isScionAddress, _ := rainsLookup(elem.From()) - outStm := newOutStream(s.router, isScionAddress) - _ = outStm.start(outCfg) + verifyCh := outStm.verify(ctx, s.sess.StreamID(), elem.To(), elem.From(), elem.Text()) // wait remote server verification select { - case valid := <-outStm.verify(): + case valid := <-verifyCh: reply := xmpp.NewElementName("db:result") reply.SetFrom(elem.To()) reply.SetTo(elem.From()) @@ -353,19 +356,19 @@ func (s *inStream) authorizeDialbackKey(elem xmpp.XElement) { } else { reply.SetType("invalid") } - s.writeElement(reply) - outStm.Disconnect(nil) + s.writeElement(ctx, reply) + outStm.Disconnect(ctx, nil) case <-outStm.done(): // remote server closed connection unexpectedly - s.writeStanzaErrorResponse(elem, xmpp.ErrRemoteServerTimeout) + s.writeStanzaErrorResponse(ctx, elem, xmpp.ErrRemoteServerTimeout) break } } -func (s *inStream) verifyDialbackKey(elem xmpp.XElement) { - if !s.router.IsLocalHost(elem.To()) { - s.writeStanzaErrorResponse(elem, xmpp.ErrItemNotFound) +func (s *inStream) verifyDialbackKey(ctx context.Context, elem xmpp.XElement) { + if !s.router.Hosts().IsLocalHost(elem.To()) { + s.writeStanzaErrorResponse(ctx, elem, xmpp.ErrItemNotFound) return } dbVerify := xmpp.NewElementName("db:verify") @@ -381,96 +384,116 @@ func (s *inStream) verifyDialbackKey(elem xmpp.XElement) { log.Infof("failed dialback key verification... (expected: %s, got: %s)", expectedKey, elem.Text()) dbVerify.SetType("invalid") } - s.writeElement(dbVerify) + s.writeElement(ctx, dbVerify) } -func (s *inStream) writeStanzaErrorResponse(elem xmpp.XElement, stanzaErr *xmpp.StanzaError) { +func (s *inStream) writeStanzaErrorResponse(ctx context.Context, elem xmpp.XElement, stanzaErr *xmpp.StanzaError) { resp := xmpp.NewElementFromElement(elem) resp.SetType(xmpp.ErrorType) resp.SetFrom(elem.To()) resp.SetTo(elem.From()) resp.AppendElement(stanzaErr.Element()) - s.writeElement(resp) + s.writeElement(ctx, resp) } -func (s *inStream) writeElement(elem xmpp.XElement) { - s.sess.Send(elem) +func (s *inStream) writeElement(ctx context.Context, elem xmpp.XElement) { + if err := s.sess.Send(ctx, elem); err != nil { + log.Error(err) + } } -func (s *inStream) readElement(elem xmpp.XElement) { +func (s *inStream) readElement(ctx context.Context, elem xmpp.XElement) { if elem != nil { - s.handleElement(elem) + s.handleElement(ctx, elem) } if s.getState() != inDisconnected { go s.doRead() } } -func (s *inStream) handleSessionError(sErr *session.Error) { +func (s *inStream) handleSessionError(ctx context.Context, sErr *session.Error) { switch err := sErr.UnderlyingErr.(type) { case nil: - s.disconnect(nil) + s.disconnect(ctx, nil) case *streamerror.Error: - s.disconnectWithStreamError(err) + s.disconnectWithStreamError(ctx, err) case *xmpp.StanzaError: - s.writeStanzaErrorResponse(sErr.Element, err) + s.writeStanzaErrorResponse(ctx, sErr.Element, err) default: log.Error(err) - s.disconnectWithStreamError(streamerror.ErrUndefinedCondition) + s.disconnectWithStreamError(ctx, streamerror.ErrUndefinedCondition) } } -func (s *inStream) disconnect(err error) { +func (s *inStream) disconnect(ctx context.Context, err error) { if s.getState() == inDisconnected { return } switch err { case nil: - s.disconnectClosingSession(false) + s.disconnectClosingSession(ctx, false) default: if stmErr, ok := err.(*streamerror.Error); ok { - s.disconnectWithStreamError(stmErr) + s.disconnectWithStreamError(ctx, stmErr) } else { log.Error(err) - s.disconnectClosingSession(false) + s.disconnectClosingSession(ctx, false) } } } -func (s *inStream) disconnectWithStreamError(err *streamerror.Error) { +func (s *inStream) disconnectWithStreamError(ctx context.Context, err *streamerror.Error) { if s.getState() == inConnecting { - _ = s.sess.Open(nil) + _ = s.sess.Open(ctx, nil) } - s.writeElement(err.Element()) - s.disconnectClosingSession(true) + s.writeElement(ctx, err.Element()) + s.disconnectClosingSession(ctx, true) } -func (s *inStream) disconnectClosingSession(closeSession bool) { +func (s *inStream) disconnectClosingSession(ctx context.Context, closeSession bool) { if closeSession { - _ = s.sess.Close() + _ = s.sess.Close(ctx) } - if s.cfg.onInDisconnect != nil { - s.cfg.onInDisconnect(s) + if s.cfg.onDisconnect != nil { + s.cfg.onDisconnect(s) } s.setState(inDisconnected) - _ = s.cfg.transport.Close() + _ = s.tr.Close() s.runQueue.Stop(nil) // stop processing messages } func (s *inStream) restartSession() { - j, _ := jid.New("", s.cfg.localDomain, "", true) + j, _ := jid.New("", s.localDomain, "", true) s.sess = session.New(s.id, &session.Config{ JID: j, - Transport: s.cfg.transport, MaxStanzaSize: s.cfg.maxStanzaSize, RemoteDomain: s.remoteDomain, IsServer: true, - }, s.router) + }, s.tr, s.router.Hosts()) s.setState(inConnecting) } +func (s *inStream) scheduleReadTimeout() { + s.mu.Lock() + s.readTimeoutTm = time.AfterFunc(s.cfg.keepAlive, s.readTimeout) + s.mu.Unlock() +} + +func (s *inStream) cancelReadTimeout() { + s.mu.Lock() + s.readTimeoutTm.Stop() + s.mu.Unlock() +} + +func (s *inStream) readTimeout() { + s.runQueue.Run(func() { + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + s.disconnect(ctx, streamerror.ErrConnectionTimeout) + }) +} + func (s *inStream) isSecured() bool { return atomic.LoadUint32(&s.secured) == 1 } diff --git a/s2s/in_test.go b/s2s/in_test.go index 86515a0de..cae5ef929 100644 --- a/s2s/in_test.go +++ b/s2s/in_test.go @@ -6,6 +6,7 @@ package s2s import ( + "context" "crypto/x509" "fmt" "net" @@ -14,14 +15,10 @@ import ( "time" "github.com/ortuman/jackal/module" - "github.com/ortuman/jackal/module/offline" - "github.com/ortuman/jackal/module/xep0077" - "github.com/ortuman/jackal/module/xep0092" - "github.com/ortuman/jackal/module/xep0199" "github.com/ortuman/jackal/router" "github.com/ortuman/jackal/stream" "github.com/ortuman/jackal/transport" - "github.com/ortuman/jackal/util" + utiltls "github.com/ortuman/jackal/util/tls" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" @@ -29,31 +26,34 @@ import ( ) func TestStream_ConnectTimeout(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) - stm, _ := tUtilInStreamInit(t, r, false) + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) + + stm, _ := tUtilInStreamInit(t, r, op, false) time.Sleep(time.Millisecond * 1500) require.Equal(t, inDisconnected, stm.getState()) } func TestStream_Disconnect(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) + + op := NewOutProvider(&Config{}, h) - stm, conn := tUtilInStreamInit(t, r, false) - stm.Disconnect(nil) + stm, conn := tUtilInStreamInit(t, r, op, false) + stm.Disconnect(context.Background(), nil) require.True(t, conn.waitClose()) require.Equal(t, inDisconnected, stm.getState()) } func TestStream_Features(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) + + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) // unsecured features - stm, conn := tUtilInStreamInit(t, r, false) + stm, conn := tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) elem := conn.outboundRead() @@ -65,7 +65,7 @@ func TestStream_Features(t *testing.T) { require.Equal(t, inConnected, stm.getState()) // secured features - stm, conn = tUtilInStreamInit(t, r, false) + stm, conn = tUtilInStreamInit(t, r, op, false) atomic.StoreUint32(&stm.secured, 1) tUtilInStreamOpen(conn) @@ -78,7 +78,7 @@ func TestStream_Features(t *testing.T) { require.Equal(t, inConnected, stm.getState()) // secured features (authenticated) - stm, conn = tUtilInStreamInit(t, r, false) + stm, conn = tUtilInStreamInit(t, r, op, false) atomic.StoreUint32(&stm.secured, 1) atomic.StoreUint32(&stm.authenticated, 1) tUtilInStreamOpen(conn) @@ -93,33 +93,34 @@ func TestStream_Features(t *testing.T) { } func TestStream_TLS(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) + + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) - stm, conn := tUtilInStreamInit(t, r, false) + stm, conn := tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... // wrong namespace... - conn.inboundWriteString(``) + _, _ = conn.inboundWriteString(``) require.True(t, conn.waitClose()) - stm, conn = tUtilInStreamInit(t, r, false) + stm, conn = tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... // wrong name... - conn.inboundWriteString(``) + _, _ = conn.inboundWriteString(``) require.True(t, conn.waitClose()) - stm, conn = tUtilInStreamInit(t, r, false) + stm, conn = tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... - conn.inboundWriteString(``) + _, _ = conn.inboundWriteString(``) elem := conn.outboundRead() @@ -130,61 +131,63 @@ func TestStream_TLS(t *testing.T) { } func TestStream_Authenticate(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) - stm, conn := tUtilInStreamInit(t, r, false) + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) + + stm, conn := tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... atomic.StoreUint32(&stm.secured, 1) // invalid namespace... - conn.inboundWriteString(`=`) + _, _ = conn.inboundWriteString(`=`) require.True(t, conn.waitClose()) - stm, conn = tUtilInStreamInit(t, r, true) + stm, conn = tUtilInStreamInit(t, r, op, true) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... atomic.StoreUint32(&stm.secured, 1) // failed peer certificate... - stm, conn = tUtilInStreamInit(t, r, false) + stm, conn = tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(`=`) + _, _ = conn.inboundWriteString(`=`) elem := conn.outboundRead() require.Equal(t, "failure", elem.Name()) require.Equal(t, saslNamespace, elem.Namespace()) // invalid mechanism... - stm, conn = tUtilInStreamInit(t, r, true) + stm, conn = tUtilInStreamInit(t, r, op, true) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(`=`) + _, _ = conn.inboundWriteString(`=`) elem = conn.outboundRead() require.Equal(t, "failure", elem.Name()) require.Equal(t, saslNamespace, elem.Namespace()) // valid auth... - conn.inboundWriteString(`=`) + _, _ = conn.inboundWriteString(`=`) elem = conn.outboundRead() require.Equal(t, "success", elem.Name()) require.Equal(t, saslNamespace, elem.Namespace()) } func TestStream_DialbackVerify(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) + + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) - stm, conn := tUtilInStreamInit(t, r, false) + stm, conn := tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -192,14 +195,14 @@ func TestStream_DialbackVerify(t *testing.T) { atomic.StoreUint32(&stm.authenticated, 1) // invalid host - conn.inboundWriteString(`abcd`) + _, _ = conn.inboundWriteString(`abcd`) elem := conn.outboundRead() require.Equal(t, "db:verify", elem.Name()) require.NotNil(t, elem.Elements().Child("error")) require.NotNil(t, elem.Elements().Child("error").Elements().Child("item-not-found")) // invalid key - conn.inboundWriteString(`abcd`) + _, _ = conn.inboundWriteString(`abcd`) elem = conn.outboundRead() require.Equal(t, "db:verify", elem.Name()) require.Equal(t, "invalid", elem.Type()) @@ -207,40 +210,41 @@ func TestStream_DialbackVerify(t *testing.T) { // valid key kg := &keyGen{secret: "s3cr3t"} key := kg.generate("localhost", "jackal.im", "abcde") - conn.inboundWriteString(fmt.Sprintf(`%s`, key)) + + _, _ = conn.inboundWriteString(fmt.Sprintf(`%s`, key)) elem = conn.outboundRead() require.Equal(t, "db:verify", elem.Name()) require.Equal(t, "valid", elem.Type()) } func TestStream_DialbackAuthorize(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) + + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) + op.dialer.(*dialer).srvResolve = func(_, _, _ string) (cname string, addrs []*net.SRV, err error) { + return "", []*net.SRV{{Target: "jackal.im", Port: 5269}}, nil + } - stm, conn := tUtilInStreamInit(t, r, false) + stm, conn := tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... atomic.StoreUint32(&stm.secured, 1) atomic.StoreUint32(&stm.authenticated, 1) - conn.inboundWriteString(`abcd`) + _, _ = conn.inboundWriteString(`abcd`) elem := conn.outboundRead() require.Equal(t, "db:result", elem.Name()) require.Equal(t, xmpp.ErrorType, elem.Type()) require.NotNil(t, elem.Elements().Child("error")) require.NotNil(t, elem.Elements().Child("error").Elements().Child("item-not-found")) - cfg, conn := tUtilInStreamDefaultConfig(t, false) - cfg.dialer = &dialer{cfg: &Config{DialTimeout: time.Second}, router: r} - cfg.dialer.srvResolve = func(_, _, _ string) (cname string, addrs []*net.SRV, err error) { - return "", []*net.SRV{{Target: "jackal.im", Port: 5269}}, nil - } + cfg, tr, conn := tUtilInStreamDefaultConfig(t, false) outConn := newFakeSocketConn() - cfg.dialer.dialTimeout = func(_, _ string, _ time.Duration) (net.Conn, error) { + op.dialer.(*dialer).dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) { return outConn, nil } - stm = newInStream(cfg, &module.Modules{}, r, false) + stm = newInStream(cfg, tr, &module.Modules{}, op.newOut, r, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... @@ -248,8 +252,8 @@ func TestStream_DialbackAuthorize(t *testing.T) { atomic.StoreUint32(&stm.secured, 1) atomic.StoreUint32(&stm.authenticated, 1) - conn.inboundWriteString(`abcd`) - outConn.Close() + _, _ = conn.inboundWriteString(`abcd`) + _ = outConn.Close() elem = conn.outboundRead() require.Equal(t, "db:result", elem.Name()) require.Equal(t, xmpp.ErrorType, elem.Type()) @@ -257,16 +261,12 @@ func TestStream_DialbackAuthorize(t *testing.T) { require.NotNil(t, elem.Elements().Child("error").Elements().Child("remote-server-timeout")) // authorize dialback key - cfg, conn = tUtilInStreamDefaultConfig(t, false) - cfg.dialer = &dialer{cfg: &Config{DialTimeout: time.Second}, router: r} - cfg.dialer.srvResolve = func(_, _, _ string) (cname string, addrs []*net.SRV, err error) { - return "", []*net.SRV{{Target: "jackal.im", Port: 5269}}, nil - } + cfg, tr, conn = tUtilInStreamDefaultConfig(t, false) outConn = newFakeSocketConn() - cfg.dialer.dialTimeout = func(_, _ string, _ time.Duration) (net.Conn, error) { + op.dialer.(*dialer).dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) { return outConn, nil } - stm = newInStream(cfg, &module.Modules{}, r, false) + stm = newInStream(cfg, tr, &module.Modules{}, op.newOut, r, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... @@ -274,12 +274,12 @@ func TestStream_DialbackAuthorize(t *testing.T) { atomic.StoreUint32(&stm.secured, 1) atomic.StoreUint32(&stm.authenticated, 1) - conn.inboundWriteString(`abcd`) + _, _ = conn.inboundWriteString(`abcd`) - outConn.inboundWriteString(` + _, _ = outConn.inboundWriteString(` - @@ -290,15 +290,15 @@ func TestStream_DialbackAuthorize(t *testing.T) { _ = outConn.outboundRead() // stream:stream _ = outConn.outboundRead() // starttls - outConn.inboundWriteString(` + _, _ = outConn.inboundWriteString(` `) _ = outConn.outboundRead() // stream:stream - outConn.inboundWriteString(` + _, _ = outConn.inboundWriteString(` - @@ -308,7 +308,7 @@ func TestStream_DialbackAuthorize(t *testing.T) { `) _ = outConn.outboundRead() // db:verify - outConn.inboundWriteString(` + _, _ = outConn.inboundWriteString(` `) elem = conn.outboundRead() @@ -317,16 +317,19 @@ func TestStream_DialbackAuthorize(t *testing.T) { } func TestStream_SendElement(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + r, h := setupTestRouter(jackaDomain) + + op := NewOutProvider(&Config{KeepAlive: time.Second}, h) fromJID, _ := jid.New("ortuman", "localhost", "garden", true) toJID, _ := jid.New("ortuman", "jackal.im", "garden", true) stm2 := stream.NewMockC2S("abcd7890", toJID) - r.Bind(stm2) + stm2.SetPresence(xmpp.NewPresence(toJID, toJID, xmpp.AvailableType)) + + r.Bind(context.Background(), stm2) - stm, conn := tUtilInStreamInit(t, r, false) + stm, conn := tUtilInStreamInit(t, r, op, false) tUtilInStreamOpen(conn) _ = conn.outboundRead() // read stream opening... _ = conn.outboundRead() // read stream features... @@ -337,7 +340,7 @@ func TestStream_SendElement(t *testing.T) { iq := xmpp.NewIQType(iqID, xmpp.ResultType) iq.SetFromJID(fromJID) iq.SetToJID(toJID) - conn.inboundWriteString(iq.String()) + _, _ = conn.inboundWriteString(iq.String()) elem := stm2.ReceiveElement() require.Equal(t, "iq", elem.Name()) @@ -346,13 +349,13 @@ func TestStream_SendElement(t *testing.T) { // invalid from... iq.SetFrom("foo.org") - conn.inboundWriteString(iq.String()) + _, _ = conn.inboundWriteString(iq.String()) require.True(t, conn.waitClose()) } -func tUtilInStreamInit(t *testing.T, router *router.Router, loadPeerCertificate bool) (*inStream, *fakeSocketConn) { - cfg, conn := tUtilInStreamDefaultConfig(t, loadPeerCertificate) - stm := newInStream(cfg, &module.Modules{}, router, false) +func tUtilInStreamInit(t *testing.T, router router.Router, outProvider *OutProvider, loadPeerCertificate bool) (*inStream, *fakeSocketConn) { + cfg, tr, conn := tUtilInStreamDefaultConfig(t, loadPeerCertificate) + stm := newInStream(cfg, tr, &module.Modules{}, outProvider.newOut, router, false) return stm, conn } @@ -361,10 +364,10 @@ func tUtilInStreamOpen(conn *fakeSocketConn) { ` - conn.inboundWriteString(s) + _, _ = conn.inboundWriteString(s) } -func tUtilInStreamDefaultConfig(t *testing.T, loadPeerCertificate bool) (*streamConfig, *fakeSocketConn) { +func tUtilInStreamDefaultConfig(t *testing.T, loadPeerCertificate bool) (*inConfig, transport.Transport, *fakeSocketConn) { modules := map[string]struct{}{} modules["roster"] = struct{}{} modules["last_activity"] = struct{}{} @@ -378,7 +381,7 @@ func tUtilInStreamDefaultConfig(t *testing.T, loadPeerCertificate bool) (*stream certFile := "../testdata/cert/test.server.crt" certKey := "../testdata/cert/test.server.key" - cer, err := util.LoadCertificate(certKey, certFile, "localhost") + cer, err := utiltls.LoadCertificate(certKey, certFile, "localhost") require.Nil(t, err) var peerCerts []*x509.Certificate @@ -392,18 +395,11 @@ func tUtilInStreamDefaultConfig(t *testing.T, loadPeerCertificate bool) (*stream } conn := newFakeSocketConnWithPeerCerts(peerCerts) - tr := transport.NewSocketTransport(conn, 4096) - return &streamConfig{ - modConfig: &module.Config{ - Enabled: modules, - Offline: offline.Config{QueueSize: 10}, - Registration: xep0077.Config{AllowRegistration: true, AllowChange: true}, - Version: xep0092.Config{ShowOS: true}, - Ping: xep0199.Config{SendInterval: 5, Send: true}, - }, + tr := transport.NewSocketTransport(conn) + return &inConfig{ connectTimeout: time.Second, - transport: tr, + keepAlive: time.Second, maxStanzaSize: 8192, keyGen: &keyGen{secret: "s3cr3t"}, - }, conn + }, tr, conn } diff --git a/s2s/out.go b/s2s/out.go index e71a61517..6efb4630e 100644 --- a/s2s/out.go +++ b/s2s/out.go @@ -6,16 +6,18 @@ package s2s import ( + "context" "fmt" + "sync" "sync/atomic" - - "github.com/ortuman/jackal/runqueue" + "time" streamerror "github.com/ortuman/jackal/errors" "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/router/host" "github.com/ortuman/jackal/session" - "github.com/ortuman/jackal/stream" + "github.com/ortuman/jackal/transport" + "github.com/ortuman/jackal/util/runqueue" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) @@ -32,29 +34,33 @@ const ( ) type outStream struct { - started uint32 id string - cfg *streamConfig - router *router.Router + cfg *outConfig + runQueue *runqueue.RunQueue + hosts *host.Hosts + dialer Dialer state uint32 + tr transport.Transport + mu sync.RWMutex sess *session.Session + readTimeoutTm *time.Timer secured uint32 authenticated uint32 - sendQueue []xmpp.XElement - verified chan xmpp.XElement + pendingSendQ []xmpp.XElement + dbVerify xmpp.XElement verifyCh chan bool discCh chan *streamerror.Error - runQueue *runqueue.RunQueue - onDisconnect func(s stream.S2SOut) } -func newOutStream(router *router.Router, alreadySecuredAndAuthd bool) *outStream { +func newOutStream(cfg *outConfig, hosts *host.Hosts, dialer Dialer, alreadySecuredAndAuthd bool) *outStream { id := nextOutID() s := &outStream{ id: id, - router: router, - verifyCh: make(chan bool, 1), - discCh: make(chan *streamerror.Error, 1), + cfg: cfg, + hosts: hosts, + dialer: dialer, + state: outDisconnected, + discCh: make(chan *streamerror.Error), runQueue: runqueue.New(id), } if alreadySecuredAndAuthd { @@ -68,111 +74,159 @@ func (s *outStream) ID() string { return s.cfg.localDomain + ":" + s.cfg.remoteDomain } -func (s *outStream) SendElement(elem xmpp.XElement) { - if s.getState() == outDisconnected { - return - } +func (s *outStream) SendElement(ctx context.Context, elem xmpp.XElement) { s.runQueue.Run(func() { - if s.getState() != outVerified { - // send element after verification has been completed - s.sendQueue = append(s.sendQueue, elem) - return - } - s.writeElement(elem) + s.sendElement(ctx, elem) }) } -func (s *outStream) Disconnect(err error) { - if s.getState() == outDisconnected { - return - } +func (s *outStream) Disconnect(ctx context.Context, err error) { waitCh := make(chan struct{}) - s.runQueue.Run(func() { - s.disconnect(err) - close(waitCh) + s.runQueue.Stop(func() { + defer close(waitCh) + if s.getState() == outDisconnected { + return + } + s.disconnect(ctx, err) }) <-waitCh } -func (s *outStream) start(cfg *streamConfig) error { - if cfg.dbVerify != nil && cfg.dbVerify.Name() != "db:verify" { - return fmt.Errorf("wrong dialback verification element name: %s", cfg.dbVerify.Name()) - } - if !atomic.CompareAndSwapUint32(&s.started, 0, 1) { - return fmt.Errorf("stream already started (domainpair: %s)", s.ID()) +func (s *outStream) sendElement(ctx context.Context, elem xmpp.XElement) { + switch s.getState() { + case outVerified: + s.writeElement(ctx, elem) + case outDisconnected: + if err := s.start(ctx); err != nil { + log.Error(err) + return + } + fallthrough + default: + // send element after verification has been completed + s.pendingSendQ = append(s.pendingSendQ, elem) + return } - s.cfg = cfg +} - // start s2s out session - s.restartSession() +func (s *outStream) verify(ctx context.Context, streamID, from, to, key string) <-chan bool { + verifyCh := make(chan bool, 1) + s.runQueue.Run(func() { + dbVerify := xmpp.NewElementName("db:verify") + dbVerify.SetID(streamID) + dbVerify.SetFrom(from) + dbVerify.SetTo(to) + dbVerify.SetText(key) - go s.doRead() // start reading transport... + s.dbVerify = dbVerify + s.verifyCh = verifyCh - s.runQueue.Run(func() { - _ = s.sess.Open(nil) + if err := s.start(ctx); err != nil { + log.Error(err) + return + } }) - return nil + return verifyCh } -func (s *outStream) verify() <-chan bool { return s.verifyCh } func (s *outStream) done() <-chan *streamerror.Error { return s.discCh } // runs on its own goroutine func (s *outStream) doRead() { - if elem, sErr := s.sess.Receive(); sErr == nil { - s.runQueue.Run(func() { - s.readElement(elem) - }) + s.scheduleReadTimeout() + elem, sErr := s.sess.Receive() + s.cancelReadTimeout() + + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + if sErr == nil { + s.runQueue.Run(func() { s.readElement(ctx, elem) }) } else { s.runQueue.Run(func() { if s.getState() == outDisconnected { return // already disconnected... } - s.handleSessionError(sErr) + log.Infof("s2s out stream disconnected... (domainpair: %s)", s.ID()) + + s.handleSessionError(ctx, sErr) }) } } -func (s *outStream) handleElement(elem xmpp.XElement) { +func (s *outStream) dial(ctx context.Context) error { + isSCIONAddress, remote := rainsLookup(s.cfg.remoteDomain) + if isSCIONAddress { + sess, err := s.dialer.DialQUIC(s.cfg.scion, remote, s.cfg.localDomain, s.cfg.remoteDomain) + if err != nil { + return err + } + biStream, err := sess.OpenStreamSync() + if err != nil { + return err + } + s.tr = transport.NewQUICSocketTransport(sess, biStream) + } else { + conn, err := s.dialer.DialTCP(ctx, s.cfg.remoteDomain) + if err != nil { + return err + } + s.tr = transport.NewSocketTransport(conn) + } + return nil +} + +func (s *outStream) start(ctx context.Context) error { + if err := s.dial(ctx); err != nil { + return err + } + s.restartSession() + + _ = s.sess.Open(ctx, nil) + + go s.doRead() // start reading transport... + + return nil +} + +func (s *outStream) handleElement(ctx context.Context, elem xmpp.XElement) { switch s.getState() { case outConnecting: - s.handleConnecting(elem) + s.handleConnecting() case outConnected: - s.handleConnected(elem) + s.handleConnected(ctx, elem) case outSecuring: - s.handleSecuring(elem) + s.handleSecuring(ctx, elem) case outAuthenticating: - s.handleAuthenticating(elem) + s.handleAuthenticating(ctx, elem) case outValidatingDialbackKey: - s.handleValidatingDialbackKey(elem) + s.handleValidatingDialbackKey(ctx, elem) case outAuthorizingDialbackKey: - s.handleAuthorizingDialbackKey(elem) + s.handleAuthorizingDialbackKey(ctx, elem) } } -func (s *outStream) handleConnecting(elem xmpp.XElement) { +func (s *outStream) handleConnecting() { s.setState(outConnected) } -func (s *outStream) handleConnected(elem xmpp.XElement) { +func (s *outStream) handleConnected(ctx context.Context, elem xmpp.XElement) { if elem.Name() != "stream:features" { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) return } if !s.isSecured() { if elem.Elements().ChildrenNamespace("starttls", tlsNamespace) == nil { // unsecured channels not supported - s.disconnectWithStreamError(streamerror.ErrPolicyViolation) + s.disconnectWithStreamError(ctx, streamerror.ErrPolicyViolation) return } s.setState(outSecuring) - s.writeElement(xmpp.NewElementNamespace("starttls", tlsNamespace)) + s.writeElement(ctx, xmpp.NewElementNamespace("starttls", tlsNamespace)) } else { // authorize dialback key - if s.cfg.dbVerify != nil { + if s.dbVerify != nil { s.setState(outAuthorizingDialbackKey) - s.writeElement(s.cfg.dbVerify) + s.writeElement(ctx, s.dbVerify) return } if !s.isAuthenticated() { @@ -190,7 +244,7 @@ func (s *outStream) handleConnected(elem xmpp.XElement) { auth := xmpp.NewElementNamespace("auth", saslNamespace) auth.SetAttribute("mechanism", "EXTERNAL") auth.SetText("=") - s.writeElement(auth) + s.writeElement(ctx, auth) } else if elem.Elements().ChildrenNamespace("dialback", dialbackNamespace) != nil { s.setState(outValidatingDialbackKey) @@ -198,175 +252,198 @@ func (s *outStream) handleConnected(elem xmpp.XElement) { db.SetFrom(s.cfg.localDomain) db.SetTo(s.cfg.remoteDomain) db.SetText(s.cfg.keyGen.generate(s.cfg.remoteDomain, s.cfg.localDomain, s.sess.StreamID())) - s.writeElement(db) + s.writeElement(ctx, db) } else { // no verification mechanism found... do not allow remote connection - s.disconnectWithStreamError(streamerror.ErrRemoteConnectionFailed) + s.disconnectWithStreamError(ctx, streamerror.ErrRemoteConnectionFailed) } } else { - s.finishVerification() + s.finishVerification(ctx) } } } -func (s *outStream) handleSecuring(elem xmpp.XElement) { +func (s *outStream) handleSecuring(ctx context.Context, elem xmpp.XElement) { if elem.Name() != "proceed" { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) return } else if elem.Namespace() != tlsNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } - s.cfg.transport.StartTLS(s.cfg.tls, true) + s.tr.StartTLS(s.cfg.tls, true) atomic.StoreUint32(&s.secured, 1) s.restartSession() - _ = s.sess.Open(nil) + + _ = s.sess.Open(ctx, nil) } -func (s *outStream) handleAuthenticating(elem xmpp.XElement) { +func (s *outStream) handleAuthenticating(ctx context.Context, elem xmpp.XElement) { if elem.Namespace() != saslNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidNamespace) return } switch elem.Name() { case "success": atomic.StoreUint32(&s.authenticated, 1) s.restartSession() - _ = s.sess.Open(nil) + _ = s.sess.Open(ctx, nil) case "failure": - s.disconnectWithStreamError(streamerror.ErrRemoteConnectionFailed) + s.disconnectWithStreamError(ctx, streamerror.ErrRemoteConnectionFailed) default: - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) } } -func (s *outStream) handleValidatingDialbackKey(elem xmpp.XElement) { +func (s *outStream) handleValidatingDialbackKey(ctx context.Context, elem xmpp.XElement) { switch elem.Name() { case "db:result": if elem.From() != s.cfg.remoteDomain { - s.disconnectWithStreamError(streamerror.ErrInvalidFrom) + s.disconnectWithStreamError(ctx, streamerror.ErrInvalidFrom) return } switch elem.Type() { case "valid": log.Infof("s2s out stream successfully validated... (domainpair: %s)", s.ID()) - s.finishVerification() + s.finishVerification(ctx) default: log.Infof("failed s2s out stream validation... (domainpair: %s)", s.ID()) - s.disconnectWithStreamError(streamerror.ErrRemoteConnectionFailed) + s.disconnectWithStreamError(ctx, streamerror.ErrRemoteConnectionFailed) } } } -func (s *outStream) handleAuthorizingDialbackKey(elem xmpp.XElement) { +func (s *outStream) handleAuthorizingDialbackKey(ctx context.Context, elem xmpp.XElement) { switch elem.Name() { case "db:verify": s.verifyCh <- elem.Type() == "valid" default: - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + s.disconnectWithStreamError(ctx, streamerror.ErrUnsupportedStanzaType) } } -func (s *outStream) finishVerification() { +func (s *outStream) finishVerification(ctx context.Context) { + s.setState(outVerified) + // send pending elements... - for _, el := range s.sendQueue { - s.writeElement(el) + for _, el := range s.pendingSendQ { + s.writeElement(ctx, el) } - s.sendQueue = nil - s.setState(outVerified) + s.pendingSendQ = nil } -func (s *outStream) writeStanzaErrorResponse(elem xmpp.XElement, stanzaErr *xmpp.StanzaError) { +func (s *outStream) writeStanzaErrorResponse(ctx context.Context, elem xmpp.XElement, stanzaErr *xmpp.StanzaError) { resp := xmpp.NewElementFromElement(elem) resp.SetType(xmpp.ErrorType) resp.SetFrom(elem.To()) resp.SetTo(elem.From()) resp.AppendElement(stanzaErr.Element()) - s.writeElement(resp) + s.writeElement(ctx, resp) } -func (s *outStream) writeElement(elem xmpp.XElement) { - s.sess.Send(elem) +func (s *outStream) writeElement(ctx context.Context, elem xmpp.XElement) { + if err := s.sess.Send(ctx, elem); err != nil { + log.Error(err) + } } -func (s *outStream) readElement(elem xmpp.XElement) { +func (s *outStream) readElement(ctx context.Context, elem xmpp.XElement) { if elem != nil { - s.handleElement(elem) + s.handleElement(ctx, elem) } if s.getState() != outDisconnected { go s.doRead() } } -func (s *outStream) handleSessionError(sErr *session.Error) { +func (s *outStream) handleSessionError(ctx context.Context, sErr *session.Error) { switch err := sErr.UnderlyingErr.(type) { case nil: - s.disconnect(nil) + s.disconnect(ctx, nil) case *streamerror.Error: - s.disconnectWithStreamError(err) + s.disconnectWithStreamError(ctx, err) case *xmpp.StanzaError: - s.writeStanzaErrorResponse(sErr.Element, err) + s.writeStanzaErrorResponse(ctx, sErr.Element, err) default: log.Error(err) - s.disconnectWithStreamError(streamerror.ErrUndefinedCondition) + s.disconnectWithStreamError(ctx, streamerror.ErrUndefinedCondition) } } -func (s *outStream) disconnect(err error) { +func (s *outStream) disconnect(ctx context.Context, err error) { switch err { case nil: - s.disconnectClosingSession(false) + s.disconnectClosingSession(ctx, false) default: if stmErr, ok := err.(*streamerror.Error); ok { - s.disconnectWithStreamError(stmErr) + s.disconnectWithStreamError(ctx, stmErr) } else { log.Error(err) - s.disconnectClosingSession(false) + s.disconnectClosingSession(ctx, false) } } } -func (s *outStream) disconnectWithStreamError(err *streamerror.Error) { - s.discCh <- err - s.writeElement(err.Element()) - s.disconnectClosingSession(true) +func (s *outStream) disconnectWithStreamError(ctx context.Context, err *streamerror.Error) { + // notify disconnection + select { + case s.discCh <- err: + break + default: + break + } + s.writeElement(ctx, err.Element()) + s.disconnectClosingSession(ctx, true) } -func (s *outStream) disconnectClosingSession(closeSession bool) { +func (s *outStream) disconnectClosingSession(ctx context.Context, closeSession bool) { if closeSession { - _ = s.sess.Close() - } - if s.cfg.onOutDisconnect != nil { - s.cfg.onOutDisconnect(s) + _ = s.sess.Close(ctx) } + atomic.StoreUint32(&s.secured, 0) + atomic.StoreUint32(&s.authenticated, 0) s.setState(outDisconnected) - _ = s.cfg.transport.Close() - - s.runQueue.Stop(nil) // stop processing messages - - close(s.discCh) + _ = s.tr.Close() } func (s *outStream) restartSession() { j, _ := jid.New("", s.cfg.localDomain, "", true) s.sess = session.New(s.id, &session.Config{ JID: j, - Transport: s.cfg.transport, MaxStanzaSize: s.cfg.maxStanzaSize, RemoteDomain: s.cfg.remoteDomain, IsServer: true, IsInitiating: true, - }, s.router) + }, s.tr, s.hosts) s.setState(outConnecting) } +func (s *outStream) scheduleReadTimeout() { + s.mu.Lock() + s.readTimeoutTm = time.AfterFunc(s.cfg.keepAlive, s.readTimeout) + s.mu.Unlock() +} + +func (s *outStream) cancelReadTimeout() { + s.mu.Lock() + s.readTimeoutTm.Stop() + s.mu.Unlock() +} + +func (s *outStream) readTimeout() { + s.runQueue.Run(func() { + ctx, _ := context.WithTimeout(context.Background(), s.cfg.timeout) + s.disconnect(ctx, streamerror.ErrConnectionTimeout) + }) +} + func (s *outStream) isSecured() bool { return atomic.LoadUint32(&s.secured) == 1 } diff --git a/s2s/out_test.go b/s2s/out_test.go index 7044fb9fa..4464d659a 100644 --- a/s2s/out_test.go +++ b/s2s/out_test.go @@ -6,255 +6,178 @@ package s2s import ( + "context" + "net" "sync/atomic" "testing" "time" - "github.com/ortuman/jackal/module" - "github.com/ortuman/jackal/module/offline" - "github.com/ortuman/jackal/module/xep0077" - "github.com/ortuman/jackal/module/xep0092" - "github.com/ortuman/jackal/module/xep0199" - "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/transport" + "github.com/ortuman/jackal/router/host" "github.com/ortuman/jackal/xmpp" "github.com/pborman/uuid" "github.com/stretchr/testify/require" ) -func TestOutStream_Start(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() - - cfg, _ := tUtilOutStreamDefaultConfig() - stm := newOutStream(r, false) - defer stm.Disconnect(nil) - - // wrong verification name... - cfg.dbVerify = xmpp.NewElementName("foo") - err := stm.start(cfg) - require.NotNil(t, err) - - cfg.dbVerify = nil - stm.start(cfg) - err = stm.start(cfg) - require.NotNil(t, err) // already started -} - func TestOutStream_Disconnect(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + h := setupTestHosts(jackaDomain) + + cfg, dialer, conn := tUtilOutStreamDefaultConfig() + stm := newOutStream(cfg, h, dialer, false) + _ = stm.start(context.Background()) - cfg, conn := tUtilOutStreamDefaultConfig() - stm := newOutStream(r, false) - stm.start(cfg) - stm.Disconnect(nil) + stm.Disconnect(context.Background(), nil) require.True(t, conn.waitClose()) require.Equal(t, outDisconnected, stm.getState()) } func TestOutStream_BadConnect(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + h := setupTestHosts(jackaDomain) - _, conn := tUtilOutStreamInit(t, r) + _, conn := tUtilOutStreamInit(t, h) // invalid namespace - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) } func TestOutStream_Features(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + h := setupTestHosts(jackaDomain) - _, conn := tUtilOutStreamInit(t, r) + _, conn := tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) // invalid stanza type... - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) // invalid namespace... - _, conn = tUtilOutStreamInit(t, r) + _, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) - conn.inboundWriteString(` + + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) // invalid version... - _, conn = tUtilOutStreamInit(t, r) + _, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) - conn.inboundWriteString(` + + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) // starttls not available... - _, conn = tUtilOutStreamInit(t, r) + _, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) } -func TestOutStream_DBVerify(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() - - cfg, conn := tUtilOutStreamDefaultConfig() - dbVerify := xmpp.NewElementName("db:verify") - key := uuid.New() - dbVerify.SetID("abcde") - dbVerify.SetFrom("jackal.im") - dbVerify.SetTo("jabber.org") - dbVerify.SetText(key) - cfg.dbVerify = dbVerify - - stm := tUtilOutStreamInitWithConfig(t, r, cfg, conn) - atomic.StoreUint32(&stm.secured, 1) - tUtilOutStreamOpen(conn) - - conn.inboundWriteString(securedFeatures) - elem := conn.outboundRead() - require.Equal(t, "db:verify", elem.Name()) - require.Equal(t, key, elem.Text()) - - // unsupported stanza... - conn.inboundWriteString(` - -`) - select { - case sErr := <-stm.done(): - require.Equal(t, "unsupported-stanza-type", sErr.Error()) - case <-time.After(time.Second): - require.Fail(t, "expecting session error") - } - - cfg, conn = tUtilOutStreamDefaultConfig() - cfg.dbVerify = dbVerify - stm = tUtilOutStreamInitWithConfig(t, r, cfg, conn) - atomic.StoreUint32(&stm.secured, 1) - tUtilOutStreamOpen(conn) - conn.inboundWriteString(securedFeatures) - _ = conn.outboundRead() - - conn.inboundWriteString(` - -`) - select { - case ok := <-stm.verify(): - require.True(t, ok) - case <-time.After(time.Second): - require.Fail(t, "expecting dialback valid verification") - } -} - func TestOutStream_StartTLS(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + h := setupTestHosts(jackaDomain) // unsupported stanza... - _, conn := tUtilOutStreamInit(t, r) + _, conn := tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) - conn.inboundWriteString(unsecuredFeatures) + _, _ = conn.inboundWriteString(unsecuredFeatures) elem := conn.outboundRead() require.Equal(t, "starttls", elem.Name()) require.Equal(t, tlsNamespace, elem.Namespace()) - conn.inboundWriteString(``) + _, _ = conn.inboundWriteString(``) require.True(t, conn.waitClose()) // invalid namespace - _, conn = tUtilOutStreamInit(t, r) + _, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) - conn.inboundWriteString(unsecuredFeatures) + _, _ = conn.inboundWriteString(unsecuredFeatures) _ = conn.outboundRead() - conn.inboundWriteString(``) + _, _ = conn.inboundWriteString(``) require.True(t, conn.waitClose()) // valid - stm, conn := tUtilOutStreamInit(t, r) + stm, conn := tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) - conn.inboundWriteString(unsecuredFeatures) + _, _ = conn.inboundWriteString(unsecuredFeatures) _ = conn.outboundRead() - conn.inboundWriteString(``) + _, _ = conn.inboundWriteString(``) _ = conn.outboundRead() require.True(t, stm.isSecured()) } func TestOutStream_Authenticate(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + h := setupTestHosts(jackaDomain) // unsupported stanza... - stm, conn := tUtilOutStreamInit(t, r) + stm, conn := tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeaturesWithExternal) + _, _ = conn.inboundWriteString(securedFeaturesWithExternal) elem := conn.outboundRead() require.Equal(t, "auth", elem.Name()) require.Equal(t, "urn:ietf:params:xml:ns:xmpp-sasl", elem.Namespace()) require.Equal(t, "EXTERNAL", elem.Attributes().Get("mechanism")) - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) - stm, conn = tUtilOutStreamInit(t, r) + stm, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeaturesWithExternal) + _, _ = conn.inboundWriteString(securedFeaturesWithExternal) _ = conn.outboundRead() - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) - stm, conn = tUtilOutStreamInit(t, r) + stm, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeaturesWithExternal) + _, _ = conn.inboundWriteString(securedFeaturesWithExternal) _ = conn.outboundRead() - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) - stm, conn = tUtilOutStreamInit(t, r) + stm, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeaturesWithExternal) + _, _ = conn.inboundWriteString(securedFeaturesWithExternal) _ = conn.outboundRead() // store pending stanza... iqID := uuid.New() iq := xmpp.NewIQType(iqID, xmpp.GetType) iq.AppendElement(xmpp.NewElementNamespace("query", "jabber:foo")) - stm.SendElement(iq) + stm.SendElement(context.Background(), iq) - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) elem = conn.outboundRead() require.True(t, stm.isAuthenticated()) tUtilOutStreamOpen(conn) - conn.inboundWriteString(securedFeaturesWithExternal) + _, _ = conn.inboundWriteString(securedFeaturesWithExternal) elem = conn.outboundRead() // ...expect receiving pending stanza require.Equal(t, "iq", elem.Name()) @@ -262,49 +185,48 @@ func TestOutStream_Authenticate(t *testing.T) { } func TestOutStream_Dialback(t *testing.T) { - r, _, shutdown := setupTest(jackaDomain) - defer shutdown() + h := setupTestHosts(jackaDomain) // unsupported stanza... - stm, conn := tUtilOutStreamInit(t, r) + stm, conn := tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeatures) + _, _ = conn.inboundWriteString(securedFeatures) elem := conn.outboundRead() require.Equal(t, "db:result", elem.Name()) // invalid from... - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) // failed - stm, conn = tUtilOutStreamInit(t, r) + stm, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeatures) + _, _ = conn.inboundWriteString(securedFeatures) _ = conn.outboundRead() - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) require.True(t, conn.waitClose()) // successful - stm, conn = tUtilOutStreamInit(t, r) + stm, conn = tUtilOutStreamInit(t, h) tUtilOutStreamOpen(conn) atomic.StoreUint32(&stm.secured, 1) - conn.inboundWriteString(securedFeatures) + _, _ = conn.inboundWriteString(securedFeatures) _ = conn.outboundRead() iqID := uuid.New() iq := xmpp.NewIQType(iqID, xmpp.GetType) - stm.SendElement(iq) //...store pending... + stm.SendElement(context.Background(), iq) //...store pending... - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) elem = conn.outboundRead() @@ -314,7 +236,7 @@ func TestOutStream_Dialback(t *testing.T) { func tUtilOutStreamOpen(conn *fakeSocketConn) { // open stream from remote server... - conn.inboundWriteString(` + _, _ = conn.inboundWriteString(` `) - case transport.WebSocket: - ops = xmpp.NewElementName("open") - ops.SetAttribute("xmlns", framedStreamNamespace) - includeClosing = true - default: return nil } @@ -173,46 +162,60 @@ func (s *Session) Open(featuresElem xmpp.XElement) error { s.mu.RUnlock() } ops.SetAttribute("version", "1.0") - ops.ToXML(buf, includeClosing) + if err := ops.ToXML(buf, includeClosing); err != nil { + return err + } if featuresElem != nil { - featuresElem.ToXML(buf, true) + if err := featuresElem.ToXML(buf, true); err != nil { + return err + } } openStr := buf.String() log.Debugf("SEND(%s): %s", s.id, openStr) + s.setWriteDeadline(ctx) + _, err := io.Copy(s.tr, strings.NewReader(openStr)) - _ = s.tr.Flush() - return err + if err != nil { + return err + } + return s.tr.Flush() } // Close closes session sending the proper XMPP payload. // Is responsibility of the caller to close underlying transport. -func (s *Session) Close() error { +func (s *Session) Close(ctx context.Context) error { if atomic.LoadUint32(&s.opened) == 0 { return errors.New("session already closed") } + s.setWriteDeadline(ctx) + + var err error switch s.tr.Type() { case transport.Socket: - io.WriteString(s.tr, "") - case transport.WebSocket: - io.WriteString(s.tr, fmt.Sprintf(``, framedStreamNamespace)) + _, err = io.WriteString(s.tr, "") } - _ = s.tr.Flush() - - return nil + if err != nil { + return err + } + return s.tr.Flush() } // Send writes an XML element to the underlying session transport. -func (s *Session) Send(elem xmpp.XElement) { +func (s *Session) Send(ctx context.Context, elem xmpp.XElement) error { // clear namespace if sending a stanza if e, ok := elem.(namespaceSettable); elem.IsStanza() && ok { e.SetNamespace("") } log.Debugf("SEND(%s): %v", s.id, elem) - elem.ToXML(s.tr, true) - _ = s.tr.Flush() + s.setWriteDeadline(ctx) + + if err := elem.ToXML(s.tr, true); err != nil { + return err + } + return s.tr.Flush() } // Receive returns next incoming session element. @@ -245,6 +248,14 @@ func (s *Session) Receive() (xmpp.XElement, *Error) { return elem, nil } +func (s *Session) setWriteDeadline(ctx context.Context) { + d, ok := ctx.Deadline() + if !ok { + return + } + _ = s.tr.SetWriteDeadline(d) +} + func (s *Session) buildStanza(elem xmpp.XElement) (xmpp.Stanza, *Error) { if err := s.validateNamespace(elem); err != nil { return nil, err @@ -340,17 +351,9 @@ func (s *Session) validateStreamElement(elem xmpp.XElement) *Error { if elem.Namespace() != s.namespace() || elem.Attributes().Get("xmlns:stream") != streamNamespace { return &Error{UnderlyingErr: streamerror.ErrInvalidNamespace} } - - case transport.WebSocket: - if elem.Name() != "open" { - return &Error{UnderlyingErr: streamerror.ErrUnsupportedStanzaType} - } - if elem.Namespace() != framedStreamNamespace { - return &Error{UnderlyingErr: streamerror.ErrInvalidNamespace} - } } to := elem.To() - if len(to) > 0 && !s.router.IsLocalHost(to) { + if len(to) > 0 && !s.hosts.IsLocalHost(to) { return &Error{UnderlyingErr: streamerror.ErrHostUnknown} } if elem.Version() != "1.0" { @@ -386,7 +389,7 @@ func (s *Session) mapErrorToSessionError(err error) *Error { break case xmpp.ErrStreamClosedByPeer: - s.Close() + _ = s.Close(context.Background()) case xmpp.ErrTooLargeStanza: return &Error{UnderlyingErr: streamerror.ErrPolicyViolation} diff --git a/session/session_test.go b/session/session_test.go index 24d0ff3c5..90f0694a8 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -7,22 +7,22 @@ package session import ( "bytes" + "context" "crypto/tls" "crypto/x509" stdxml "encoding/xml" + "errors" "io" "testing" + "time" streamerror "github.com/ortuman/jackal/errors" - "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/memstorage" + "github.com/ortuman/jackal/router/host" "github.com/ortuman/jackal/transport" "github.com/ortuman/jackal/transport/compress" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" - "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -41,27 +41,27 @@ func (t *fakeTransport) Write(p []byte) (n int, err error) func (t *fakeTransport) Close() error { return nil } func (t *fakeTransport) Type() transport.Type { return t.typ } func (t *fakeTransport) Flush() error { return nil } +func (t *fakeTransport) SetWriteDeadline(_ time.Time) error { return nil } func (t *fakeTransport) WriteString(s string) (n int, err error) { return t.wrBuf.WriteString(s) } -func (t *fakeTransport) StartTLS(cfg *tls.Config, asClient bool) {} +func (t *fakeTransport) StartTLS(_ *tls.Config, _ bool) {} func (t *fakeTransport) EnableCompression(compress.Level) {} func (t *fakeTransport) ChannelBindingBytes(transport.ChannelBindingMechanism) []byte { return nil } func (t *fakeTransport) PeerCertificates() []*x509.Certificate { return nil } func TestSession_Open(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("jackal.im", true) // test client socket session start tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) - require.NotNil(t, sess.Close()) + require.NotNil(t, sess.Close(context.Background())) _, err1 := sess.Receive() require.NotNil(t, err1) - sess.Open(nil) + _ = sess.Open(context.Background(), nil) pr := xmpp.NewParser(tr.wrBuf, xmpp.SocketStream, 0) _, _ = pr.ParseElement() // read xml header elem, err := pr.ParseElement() @@ -72,113 +72,99 @@ func TestSession_Open(t *testing.T) { // test server socket session start tr.wrBuf.Reset() - sess = New(uuid.New(), &Config{JID: j, Transport: tr, IsServer: true}, r) - sess.Open(nil) + sess = New(uuid.New(), &Config{JID: j, IsServer: true}, tr, hosts) + + _ = sess.Open(context.Background(), nil) pr = xmpp.NewParser(tr.wrBuf, xmpp.SocketStream, 0) _, _ = pr.ParseElement() // read xml header elem, err = pr.ParseElement() require.Nil(t, err) require.Equal(t, "jabber:server", elem.Namespace()) - // test websocket session start - tr = newFakeTransport(transport.WebSocket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) - pr = xmpp.NewParser(tr.wrBuf, xmpp.WebSocketStream, 0) - elem, err = pr.ParseElement() - require.Nil(t, err) - require.Equal(t, "open", elem.Name()) - require.Equal(t, "urn:ietf:params:xml:ns:xmpp-framing", elem.Attributes().Get("xmlns")) - // test unsupported transport type tr = newFakeTransport(transport.Type(9999)) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - require.Nil(t, sess.Open(nil)) + sess = New(uuid.New(), &Config{JID: j}, tr, hosts) + require.Nil(t, sess.Open(context.Background(), nil)) // open twice - require.NotNil(t, sess.Open(nil)) + require.NotNil(t, sess.Open(context.Background(), nil)) } func TestSession_Close(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("jackal.im", true) tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) + _ = sess.Open(context.Background(), nil) tr.wrBuf.Reset() - sess.Close() + _ = sess.Close(context.Background()) require.Equal(t, "", tr.wrBuf.String()) - - tr = newFakeTransport(transport.WebSocket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) - tr.wrBuf.Reset() - - sess.Close() - require.Equal(t, ``, tr.wrBuf.String()) } func TestSession_Send(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("ortuman@jackal.im/res", true) tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) elem := xmpp.NewElementNamespace("open", "urn:ietf:params:xml:ns:xmpp-framing") - sess.Open(nil) + + _ = sess.Open(context.Background(), nil) tr.wrBuf.Reset() - sess.Send(elem) + _ = sess.Send(context.Background(), elem) require.Equal(t, elem.String(), tr.wrBuf.String()) } func TestSession_Receive(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("ortuman@jackal.im/res", true) - tr := newFakeTransport(transport.WebSocket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) + tr := newFakeTransport(transport.Socket) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) _, err := sess.Receive() require.Equal(t, &Error{}, err) - tr = newFakeTransport(transport.WebSocket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) - open := xmpp.NewElementNamespace("open", "") - open.ToXML(tr.rdBuf, true) + tr = newFakeTransport(transport.Socket) + sess = New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) + open := xmpp.NewElementNamespace("stream:stream", "") + _ = open.ToXML(tr.rdBuf, false) _, err = sess.Receive() require.Equal(t, &Error{UnderlyingErr: streamerror.ErrInvalidNamespace}, err) - tr = newFakeTransport(transport.WebSocket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) - open.SetNamespace("urn:ietf:params:xml:ns:xmpp-framing") + tr = newFakeTransport(transport.Socket) + sess = New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) + open.SetNamespace(jabberClientNamespace) + open.SetAttribute("xmlns:stream", streamNamespace) open.SetVersion("1.0") - open.ToXML(tr.rdBuf, true) + _ = open.ToXML(tr.rdBuf, false) iq := xmpp.NewIQType(uuid.New(), xmpp.ResultType) - iq.ToXML(tr.rdBuf, true) + _ = iq.ToXML(tr.rdBuf, true) _, err = sess.Receive() // read open stream element... st, err := sess.Receive() // read IQ... require.Nil(t, err) require.Equal(t, "iq", st.Name()) - tr = newFakeTransport(transport.WebSocket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) - open.ToXML(tr.rdBuf, true) + tr = newFakeTransport(transport.Socket) + sess = New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) + _ = open.ToXML(tr.rdBuf, false) // bad stanza - xmpp.NewElementName("iq").ToXML(tr.rdBuf, true) + _ = xmpp.NewElementName("iq").ToXML(tr.rdBuf, true) _, err = sess.Receive() // read open stream element... _, err = sess.Receive() require.NotNil(t, err) @@ -186,8 +172,7 @@ func TestSession_Receive(t *testing.T) { } func TestSession_IsValidNamespace(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") iqClient := xmpp.NewElementNamespace("iq", "jabber:client") iqServer := xmpp.NewElementNamespace("iq", "jabber:server") @@ -195,28 +180,30 @@ func TestSession_IsValidNamespace(t *testing.T) { j, _ := jid.NewWithString("jackal.im", true) tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) require.Nil(t, sess.validateNamespace(iqClient)) require.Equal(t, &Error{UnderlyingErr: streamerror.ErrInvalidNamespace}, sess.validateNamespace(iqServer)) tr = newFakeTransport(transport.Socket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr, IsServer: true}, r) - sess.Open(nil) + sess = New(uuid.New(), &Config{JID: j, IsServer: true}, tr, hosts) + + _ = sess.Open(context.Background(), nil) require.Equal(t, &Error{UnderlyingErr: streamerror.ErrInvalidNamespace}, sess.validateNamespace(iqClient)) require.Nil(t, sess.validateNamespace(iqServer)) } func TestSession_IsValidFrom(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j1, _ := jid.NewWithString("jackal.im", true) // server domain j2, _ := jid.NewWithString("ortuman@jackal.im/resource", true) // full jid with user tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j2, Transport: tr}, r) - sess.Open(nil) + sess := New(uuid.New(), &Config{JID: j2}, tr, hosts) + + _ = sess.Open(context.Background(), nil) sess.SetJID(j1) require.False(t, sess.isValidFrom("romeo@jackal.im")) @@ -225,21 +212,20 @@ func TestSession_IsValidFrom(t *testing.T) { } func TestSession_ValidateStream(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("jackal.im", true) // server domain elem1 := xmpp.NewElementNamespace("stream:stream", "") elem2 := xmpp.NewElementNamespace("stream:stream", "jabber:client") elem4 := xmpp.NewElementNamespace("open", "") - elem5 := xmpp.NewElementNamespace("open", "urn:ietf:params:xml:ns:xmpp-framing") // try socket tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) err := sess.validateStreamElement(elem1) - sess.Open(nil) + + _ = sess.Open(context.Background(), nil) require.NotNil(t, err) require.Equal(t, streamerror.ErrInvalidNamespace, err.UnderlyingErr) @@ -264,37 +250,10 @@ func TestSession_ValidateStream(t *testing.T) { elem2.SetTo("jackal.im") require.Nil(t, sess.validateStreamElement(elem2)) - - // try websocket - tr = newFakeTransport(transport.WebSocket) - sess = New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) - err = sess.validateStreamElement(elem4) - require.NotNil(t, err) - require.Equal(t, streamerror.ErrInvalidNamespace, err.UnderlyingErr) - - err = sess.validateStreamElement(elem1) - require.NotNil(t, err) - require.Equal(t, streamerror.ErrUnsupportedStanzaType, err.UnderlyingErr) - - err = sess.validateStreamElement(elem5) - require.NotNil(t, err) - require.Equal(t, streamerror.ErrUnsupportedVersion, err.UnderlyingErr) - - elem5.SetVersion("1.0") - elem5.SetTo("example.org") - - err = sess.validateStreamElement(elem5) - require.NotNil(t, err) - require.Equal(t, streamerror.ErrHostUnknown, err.UnderlyingErr) - - elem5.SetTo("jackal.im") - require.Nil(t, sess.validateStreamElement(elem5)) } func TestSession_ExtractAddresses(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j1, _ := jid.NewWithString("jackal.im", true) j2, _ := jid.NewWithString("ortuman@jackal.im/res", true) @@ -304,8 +263,9 @@ func TestSession_ExtractAddresses(t *testing.T) { iq.SetTo("romeo@example.org") tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j1, Transport: tr}, r) - sess.Open(nil) + sess := New(uuid.New(), &Config{JID: j1}, tr, hosts) + + _ = sess.Open(context.Background(), nil) from, to, err := sess.extractAddresses(iq) require.Nil(t, err) require.Equal(t, "jackal.im", from.String()) @@ -333,13 +293,13 @@ func TestSession_ExtractAddresses(t *testing.T) { } func TestSession_BuildStanza(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("ortuman@jackal.im/res", true) tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) elem := xmpp.NewElementNamespace("n", "ns") _, err := sess.buildStanza(elem) @@ -387,13 +347,13 @@ func TestSession_BuildStanza(t *testing.T) { } func TestSession_MapError(t *testing.T) { - r, _, shutdown := setupTest("jackal.im") - defer shutdown() + hosts := setupTest("jackal.im") j, _ := jid.NewWithString("ortuman@jackal.im/res", true) tr := newFakeTransport(transport.Socket) - sess := New(uuid.New(), &Config{JID: j, Transport: tr}, r) - sess.Open(nil) + sess := New(uuid.New(), &Config{JID: j}, tr, hosts) + + _ = sess.Open(context.Background(), nil) require.Equal(t, &Error{}, sess.mapErrorToSessionError(nil)) require.Equal(t, &Error{}, sess.mapErrorToSessionError(io.EOF)) @@ -407,13 +367,7 @@ func TestSession_MapError(t *testing.T) { require.Equal(t, &Error{UnderlyingErr: er}, sess.mapErrorToSessionError(er)) } -func setupTest(domain string) (*router.Router, *memstorage.Storage, func()) { - r, _ := router.New(&router.Config{ - Hosts: []router.HostConfig{{Name: domain, Certificate: tls.Certificate{}}}, - }) - s := memstorage.New() - storage.Set(s) - return r, s, func() { - storage.Unset() - } +func setupTest(domain string) *host.Hosts { + hosts, _ := host.New([]host.Config{{Name: domain, Certificate: tls.Certificate{}}}) + return hosts } diff --git a/sql/mysql.down.sql b/sql/mysql.down.sql old mode 100644 new mode 100755 index 873135dc4..7c72b57d3 --- a/sql/mysql.down.sql +++ b/sql/mysql.down.sql @@ -3,6 +3,17 @@ * See the LICENSE file for more information. */ +DROP TABLE IF EXISTS rooms_invites; +DROP TABLE IF EXISTS rooms_users; +DROP TABLE IF EXISTS rooms_config; +DROP TABLE IF EXISTS rooms; +DROP TABLE IF EXISTS resources; +DROP TABLE IF EXISTS occupants; +DROP TABLE IF EXISTS pubsub_items; +DROP TABLE IF EXISTS pubsub_subscriptions; +DROP TABLE IF EXISTS pubsub_affiliations; +DROP TABLE IF EXISTS pubsub_node_options; +DROP TABLE IF EXISTS pubsub_nodes; DROP TABLE IF EXISTS offline_messages; DROP TABLE IF EXISTS vcards; DROP TABLE IF EXISTS private_storage; @@ -11,4 +22,6 @@ DROP TABLE IF EXISTS roster_versions; DROP TABLE IF EXISTS roster_groups; DROP TABLE IF EXISTS roster_items; DROP TABLE IF EXISTS roster_notifications; +DROP TABLE IF EXISTS capabilities; +DROP TABLE IF EXISTS presences; DROP TABLE IF EXISTS users; diff --git a/sql/mysql.up.sql b/sql/mysql.up.sql old mode 100644 new mode 100755 index ee34a26be..e6e544613 --- a/sql/mysql.up.sql +++ b/sql/mysql.up.sql @@ -14,6 +14,40 @@ CREATE TABLE IF NOT EXISTS users ( created_at DATETIME NOT NULL ) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +-- presences + +CREATE TABLE IF NOT EXISTS presences ( + username VARCHAR(256) NOT NULL, + domain VARCHAR(256) NOT NULL, + resource VARCHAR(256) NOT NULL, + presence TEXT NOT NULL, + node VARCHAR(256) NOT NULL, + ver VARCHAR(256) NOT NULL, + allocation_id VARCHAR(256) NOT NULL, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + PRIMARY KEY (username, domain, resource), + + INDEX i_presences_username_domain(username, domain), + INDEX i_presences_domain_resource(domain, resource), + INDEX i_presences_allocation_id(allocation_id) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- capabilities + +CREATE TABLE IF NOT EXISTS capabilities ( + node VARCHAR(256) NOT NULL, + ver VARCHAR(256) NOT NULL, + features TEXT, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + PRIMARY KEY (node, ver) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + -- roster_notifications CREATE TABLE IF NOT EXISTS roster_notifications ( @@ -118,3 +152,147 @@ CREATE TABLE IF NOT EXISTS offline_messages ( INDEX i_offline_messages_username (username) ) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- pubsub_nodes + +CREATE TABLE IF NOT EXISTS pubsub_nodes ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + host TEXT NOT NULL, + name TEXT NOT NULL, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + INDEX i_pubsub_nodes_host (host(256)), + UNIQUE INDEX i_pubsub_nodes_host_name (host(256), name(512)) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- pubsub_node_options + +CREATE TABLE IF NOT EXISTS pubsub_node_options ( + node_id BIGINT NOT NULL, + name TEXT NOT NULL, + value TEXT NOT NULL, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + INDEX i_pubsub_node_options_node_id (node_id) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- pubsub_affiliations + +CREATE TABLE IF NOT EXISTS pubsub_affiliations ( + node_id BIGINT NOT NULL, + jid TEXT NOT NULL, + affiliation TEXT NOT NULL, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + INDEX i_pubsub_affiliations_jid (jid(512)), + UNIQUE INDEX i_pubsub_affiliations_node_id_jid (node_id, jid(512)) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- pubsub_subscriptions + +CREATE TABLE IF NOT EXISTS pubsub_subscriptions ( + node_id BIGINT NOT NULL, + subid TEXT NOT NULL, + jid TEXT NOT NULL, + subscription TEXT NOT NULL, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + INDEX i_pubsub_subscriptions_jid (jid(512)), + UNIQUE INDEX i_pubsub_subscriptions_node_id_jid (node_id, jid(512)) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- pubsub_items + +CREATE TABLE IF NOT EXISTS pubsub_items ( + node_id BIGINT NOT NULL, + item_id TEXT NOT NULL, + payload TEXT NOT NULL, + publisher TEXT NOT NULL, + updated_at DATETIME NOT NULL, + created_at DATETIME NOT NULL, + + INDEX i_pubsub_items_item_id (item_id(36)), + INDEX i_pubsub_items_node_id_created_at (node_id, created_at), + UNIQUE INDEX i_pubsub_items_node_id_item_id (node_id, item_id(36)) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- xep0045_occupants + +CREATE TABLE IF NOT EXISTS occupants ( + occupant_jid VARCHAR(512) PRIMARY KEY, + bare_jid VARCHAR(512) NOT NULL, + affiliation VARCHAR(32), + role VARCHAR(32) +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- xep0045_occupants_resources + +CREATE TABLE IF NOT EXISTS resources ( + occupant_jid VARCHAR(512) NOT NULL, + resource VARCHAR(256) NOT NULL, + + PRIMARY KEY (occupant_jid, resource), + + INDEX i_occupant_jid(occupant_jid) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- xep0045_rooms + +CREATE TABLE IF NOT EXISTS rooms ( + room_jid VARCHAR(256) PRIMARY KEY, + name TEXT, + description TEXT, + subject TEXT, + language TEXT, + locked BOOL NOT NULL, + occupants_online INT NOT NULL +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- xep0045_rooms_configurations + +CREATE TABLE IF NOT EXISTS rooms_config ( + room_jid VARCHAR(256) PRIMARY KEY, + public BOOL NOT NULL, + persistent BOOL NOT NULL, + pwd_protected BOOL NOT NULL, + password TEXT NOT NULL, + open BOOL NOT NULL, + moderated BOOL NOT NULL, + allow_invites BOOL NOT NULL, + max_occupants INT NOT NULL, + allow_subj_change BOOL NOT NULL, + non_anonymous BOOL NOT NULL, + can_send_pm VARCHAR(32) NOT NULL, + can_get_member_list VARCHAR(32) NOT NULL +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- xep0045_rooms_mapping_user_to_occupant_jids + +CREATE TABLE IF NOT EXISTS rooms_users ( + room_jid VARCHAR(256) NOT NULL, + user_jid VARCHAR(256) NOT NULL, + occupant_jid VARCHAR(512) NOT NULL, + + PRIMARY KEY(room_jid, user_jid) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- xep0045_rooms_invited_users + +CREATE TABLE IF NOT EXISTS rooms_invites ( + room_jid VARCHAR(256) NOT NULL, + user_jid VARCHAR(256) NOT NULL, + + PRIMARY KEY(room_jid, user_jid) + +) ENGINE=InnoDB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; diff --git a/sql/postgres.down.psql b/sql/postgres.down.psql old mode 100644 new mode 100755 index d0c486658..a2707def0 --- a/sql/postgres.down.psql +++ b/sql/postgres.down.psql @@ -3,13 +3,25 @@ * See the LICENSE file for more information. */ - DROP TABLE IF EXISTS offline_messages; - DROP TABLE IF EXISTS vcards; - DROP TABLE IF EXISTS private_storage; - DROP TABLE IF EXISTS blocklist_items; - DROP TABLE IF EXISTS roster_versions; - DROP TABLE IF EXISTS roster_groups; - DROP TABLE IF EXISTS roster_items; - DROP TABLE IF EXISTS roster_notifications; - DROP TABLE IF EXISTS users; - \ No newline at end of file +DROP TABLE IF EXISTS rooms_invites; +DROP TABLE IF EXISTS rooms_users; +DROP TABLE IF EXISTS rooms_config; +DROP TABLE IF EXISTS rooms; +DROP TABLE IF EXISTS resources; +DROP TABLE IF EXISTS occupants; +DROP TABLE IF EXISTS pubsub_items; +DROP TABLE IF EXISTS pubsub_subscriptions; +DROP TABLE IF EXISTS pubsub_affiliations; +DROP TABLE IF EXISTS pubsub_node_options; +DROP TABLE IF EXISTS pubsub_nodes; +DROP TABLE IF EXISTS offline_messages; +DROP TABLE IF EXISTS vcards; +DROP TABLE IF EXISTS private_storage; +DROP TABLE IF EXISTS blocklist_items; +DROP TABLE IF EXISTS roster_versions; +DROP TABLE IF EXISTS roster_groups; +DROP TABLE IF EXISTS roster_items; +DROP TABLE IF EXISTS roster_notifications; +DROP TABLE IF EXISTS capabilities; +DROP TABLE IF EXISTS presences; +DROP TABLE IF EXISTS users; diff --git a/sql/postgres.up.psql b/sql/postgres.up.psql old mode 100644 new mode 100755 index 36dd0d236..e2c24d053 --- a/sql/postgres.up.psql +++ b/sql/postgres.up.psql @@ -45,6 +45,42 @@ CREATE TABLE IF NOT EXISTS users ( SELECT enable_updated_at('users'); +-- presences + +CREATE TABLE IF NOT EXISTS presences ( + username VARCHAR(1023) NOT NULL, + domain VARCHAR(1023) NOT NULL, + resource VARCHAR(1023) NOT NULL, + presence TEXT NOT NULL, + node VARCHAR(1023) NOT NULL, + ver VARCHAR(1023) NOT NULL, + allocation_id VARCHAR(1023) NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + + PRIMARY KEY (username, domain, resource) +); + +SELECT enable_updated_at('presences'); + +CREATE INDEX IF NOT EXISTS i_presences_username_domain ON presences(username, domain); +CREATE INDEX IF NOT EXISTS i_presences_domain_resource ON presences(domain, resource); +CREATE INDEX IF NOT EXISTS i_presences_allocation_id ON presences(allocation_id); + +-- capabilities + +CREATE TABLE IF NOT EXISTS capabilities ( + node VARCHAR(1023) NOT NULL, + ver VARCHAR(1023) NOT NULL, + features TEXT NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + + PRIMARY KEY (node, ver) +); + +SELECT enable_updated_at('capabilities'); + -- roster_notifications CREATE TABLE IF NOT EXISTS roster_notifications ( @@ -149,3 +185,161 @@ CREATE TABLE IF NOT EXISTS offline_messages ( ); CREATE INDEX IF NOT EXISTS i_offline_messages_username ON offline_messages(username); + +-- pubsub_nodes + +CREATE TABLE IF NOT EXISTS pubsub_nodes ( + id BIGSERIAL, + host TEXT NOT NULL, + name TEXT NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX IF NOT EXISTS i_pubsub_nodes_host_name ON pubsub_nodes(host, name); + +SELECT enable_updated_at('pubsub_nodes'); + +-- pubsub_node_options + +CREATE TABLE IF NOT EXISTS pubsub_node_options ( + node_id BIGINT NOT NULL, + name TEXT NOT NULL, + value TEXT NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS i_pubsub_node_options_node_id ON pubsub_node_options(node_id); + +SELECT enable_updated_at('pubsub_node_options'); + +-- pubsub_affiliations + +CREATE TABLE IF NOT EXISTS pubsub_affiliations ( + node_id BIGINT NOT NULL, + jid TEXT NOT NULL, + affiliation TEXT NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS i_pubsub_affiliations_jid ON pubsub_affiliations(jid); + +CREATE UNIQUE INDEX IF NOT EXISTS i_pubsub_affiliations_node_id_jid ON pubsub_affiliations(node_id, jid); + +SELECT enable_updated_at('pubsub_affiliations'); + +-- pubsub_subscriptions + +CREATE TABLE IF NOT EXISTS pubsub_subscriptions ( + node_id BIGINT NOT NULL, + subid TEXT NOT NULL, + jid TEXT NOT NULL, + subscription TEXT NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS i_pubsub_subscriptions_jid ON pubsub_subscriptions(jid); + +CREATE UNIQUE INDEX IF NOT EXISTS i_pubsub_subscriptions_node_id_jid ON pubsub_subscriptions(node_id, jid); + +SELECT enable_updated_at('pubsub_subscriptions'); + +-- pubsub_items + +CREATE TABLE IF NOT EXISTS pubsub_items ( + node_id BIGINT NOT NULL, + item_id TEXT NOT NULL, + payload TEXT NOT NULL, + publisher TEXT NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS i_pubsub_items_item_id ON pubsub_items(item_id); + +CREATE UNIQUE INDEX IF NOT EXISTS i_pubsub_items_node_id_item_id ON pubsub_items(node_id, item_id); + +SELECT enable_updated_at('pubsub_items'); + +-- xep0045_occupants + +CREATE TABLE IF NOT EXISTS occupants ( + occupant_jid VARCHAR(1023) PRIMARY KEY, + bare_jid VARCHAR(1023) NOT NULL, + affiliation VARCHAR(32), + role VARCHAR(32) +); + +-- xep0045_occupants_resources + +CREATE TABLE IF NOT EXISTS resources ( + occupant_jid VARCHAR(512) NOT NULL, + resource VARCHAR(256) NOT NULL, + + PRIMARY KEY (occupant_jid, resource), + +); + +CREATE INDEX IF NOT EXISTS i_occupant_jid ON resources(occupant_jid); + +-- xep0045_rooms + +CREATE TABLE IF NOT EXISTS rooms ( + room_jid VARCHAR(1023) PRIMARY KEY, + name TEXT, + description TEXT, + subject TEXT, + language TEXT, + locked BOOL NOT NULL, + occupants_online INT NOT NULL +); + +-- xep0045_rooms_configurations + +CREATE TABLE IF NOT EXISTS rooms_config ( + room_jid VARCHAR(1023) PRIMARY KEY, + public BOOL NOT NULL, + persistent BOOL NOT NULL, + pwd_protected BOOL NOT NULL, + password TEXT NOT NULL, + open BOOL NOT NULL, + moderated BOOL NOT NULL, + allow_invites BOOL NOT NULL, + max_occupants INT NOT NULL, + allow_subj_change BOOL NOT NULL, + non_anonymous BOOL NOT NULL, + can_send_pm VARCHAR(32) NOT NULL, + can_get_member_list VARCHAR(32) NOT NULL +); + +-- xep0045_rooms_mapping_user_to_occupant_jids + +CREATE TABLE IF NOT EXISTS rooms_users ( + room_jid VARCHAR(512) NOT NULL, + user_jid VARCHAR(512) NOT NULL, + occupant_jid VARCHAR(512) NOT NULL, + + PRIMARY KEY(room_jid, user_jid) + +); + +CREATE INDEX IF NOT EXISTS i_room_jid_users ON rooms_users(room_jid); + +-- xep0045_rooms_invited_users + +CREATE TABLE IF NOT EXISTS rooms_invites ( + room_jid VARCHAR(512) NOT NULL, + user_jid VARCHAR(512) NOT NULL, + + PRIMARY KEY(room_jid, user_jid) + + INDEX i_room_jid_invites(room_jid) + +); + +CREATE INDEX IF NOT EXISTS i_room_jid_invites ON rooms_invites(room_jid); diff --git a/storage/badgerdb/badgerdb.go b/storage/badgerdb/badgerdb.go deleted file mode 100644 index c70a95f14..000000000 --- a/storage/badgerdb/badgerdb.go +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "errors" - "fmt" - "os" - "path/filepath" - "reflect" - "time" - - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/model/serializer" - "github.com/ortuman/jackal/pool" -) - -var ( - errBadgerDBWrongEntityType = errors.New("badgerdb: wrong entity type") - errBadgerDBEntityNotFound = errors.New("badgerdb: entity not found") -) - -// Storage represents a BadgerDB storage sub system. -type Storage struct { - db *badger.DB - pool *pool.BufferPool - doneCh chan chan bool -} - -// New returns a new BadgerDB storage instance. -func New(cfg *Config) *Storage { - b := &Storage{ - pool: pool.NewBufferPool(), - doneCh: make(chan chan bool), - } - if err := os.MkdirAll(filepath.Dir(cfg.DataDir), os.ModePerm); err != nil { - log.Fatalf("%v", err) - } - opts := badger.DefaultOptions - opts.Dir = cfg.DataDir - opts.ValueDir = cfg.DataDir - db, err := badger.Open(opts) - if err != nil { - log.Fatalf("%v", err) - } - b.db = db - go b.loop() - return b -} - -// IsClusterCompatible returns whether or not the underlying storage subsystem can be used in cluster mode. -func (b *Storage) IsClusterCompatible() bool { return false } - -// Close shuts down BadgerDB storage sub system. -func (b *Storage) Close() error { - ch := make(chan bool) - b.doneCh <- ch - <-ch - return nil -} - -func (b *Storage) loop() { - tc := time.NewTicker(time.Minute) - defer tc.Stop() - for { - select { - case <-tc.C: - if err := b.db.RunValueLogGC(0.5); err != nil { - log.Warnf("%s", err) - } - case ch := <-b.doneCh: - if err := b.db.Close(); err != nil { - log.Warnf("%s", err) - } - close(ch) - return - } - } -} - -func (b *Storage) insertOrUpdate(entity interface{}, key []byte, tx *badger.Txn) error { - gs, ok := entity.(serializer.Serializer) - if !ok { - return fmt.Errorf("%v: %T", errBadgerDBWrongEntityType, entity) - } - bts, err := serializer.Serialize(gs) - if err != nil { - return err - } - val := make([]byte, len(bts)) - copy(val, bts) - return tx.Set(key, val) -} - -func (b *Storage) delete(key []byte, txn *badger.Txn) error { - return txn.Delete(key) -} - -func (b *Storage) deletePrefix(prefix []byte, txn *badger.Txn) error { - var keys [][]byte - if err := b.forEachKey(prefix, func(key []byte) error { - keys = append(keys, key) - return nil - }); err != nil { - return err - } - for _, k := range keys { - if err := txn.Delete(k); err != nil { - return err - } - } - return nil -} - -func (b *Storage) fetch(entity interface{}, key []byte) error { - return b.db.View(func(tx *badger.Txn) error { - val, err := b.getVal(key, tx) - if err != nil { - return err - } - if val != nil { - if entity != nil { - gd, ok := entity.(serializer.Deserializer) - if !ok { - return fmt.Errorf("%v: %T", errBadgerDBWrongEntityType, entity) - } - return serializer.Deserialize(val, gd) - } - return nil - } - return errBadgerDBEntityNotFound - }) -} - -func (b *Storage) fetchAll(v interface{}, prefix []byte) error { - t := reflect.TypeOf(v).Elem() - if t.Kind() != reflect.Slice { - return fmt.Errorf("%v: %T", errBadgerDBWrongEntityType, v) - } - s := reflect.ValueOf(v).Elem() - return b.forEachKeyAndValue(prefix, func(k, val []byte) error { - e := reflect.New(t.Elem()).Elem() - i := e.Addr().Interface() - gd, ok := i.(serializer.Deserializer) - if !ok { - return fmt.Errorf("%v: %T", errBadgerDBWrongEntityType, i) - } - if err := serializer.Deserialize(val, gd); err != nil { - return err - } - s.Set(reflect.Append(s, e)) - return nil - }) -} - -func (b *Storage) getVal(key []byte, txn *badger.Txn) ([]byte, error) { - item, err := txn.Get(key) - switch err { - case nil: - break - case badger.ErrKeyNotFound: - return nil, nil - default: - return nil, err - } - return item.ValueCopy(nil) -} - -func (b *Storage) forEachKey(prefix []byte, f func(k []byte) error) error { - return b.db.View(func(txn *badger.Txn) error { - opts := badger.DefaultIteratorOptions - opts.PrefetchValues = false - iter := txn.NewIterator(opts) - defer iter.Close() - - for iter.Seek(prefix); iter.ValidForPrefix(prefix); iter.Next() { - it := iter.Item() - if err := f(it.Key()); err != nil { - return err - } - } - return nil - }) -} - -func (b *Storage) forEachKeyAndValue(prefix []byte, f func(k, v []byte) error) error { - return b.db.View(func(txn *badger.Txn) error { - iter := txn.NewIterator(badger.DefaultIteratorOptions) - defer iter.Close() - - for iter.Seek(prefix); iter.ValidForPrefix(prefix); iter.Next() { - it := iter.Item() - val, err := it.ValueCopy(nil) - if err != nil { - return err - } - if err := f(it.Key(), val); err != nil { - return err - } - } - return nil - }) -} diff --git a/storage/badgerdb/badgerdb_test.go b/storage/badgerdb/badgerdb_test.go deleted file mode 100644 index 79ac087b1..000000000 --- a/storage/badgerdb/badgerdb_test.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "io/ioutil" - "os" - - "github.com/pborman/uuid" -) - -type testBadgerDBHelper struct { - db *Storage - dataDir string -} - -func tUtilBadgerDBSetup() *testBadgerDBHelper { - h := &testBadgerDBHelper{} - dir, _ := ioutil.TempDir("", "") - h.dataDir = dir + "/com.jackal.tests.badgerdb." + uuid.New() - cfg := Config{DataDir: h.dataDir} - h.db = New(&cfg) - return h -} - -func tUtilBadgerDBTeardown(h *testBadgerDBHelper) { - _ = h.db.Close() - _ = os.RemoveAll(h.dataDir) -} diff --git a/storage/badgerdb/block_list.go b/storage/badgerdb/block_list.go deleted file mode 100644 index b4d705a81..000000000 --- a/storage/badgerdb/block_list.go +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/model" -) - -// InsertBlockListItems inserts a set of block list item entities -// into storage, only in case they haven't been previously inserted. -func (b *Storage) InsertBlockListItems(items []model.BlockListItem) error { - return b.db.Update(func(tx *badger.Txn) error { - for _, item := range items { - if err := b.insertOrUpdate(&item, b.blockListItemKey(item.Username, item.JID), tx); err != nil { - return err - } - } - return nil - }) -} - -// DeleteBlockListItems deletes a set of block list item entities from storage. -func (b *Storage) DeleteBlockListItems(items []model.BlockListItem) error { - return b.db.Update(func(tx *badger.Txn) error { - for _, item := range items { - if err := b.delete(b.blockListItemKey(item.Username, item.JID), tx); err != nil { - return err - } - } - return nil - }) -} - -// FetchBlockListItems retrieves from storage all block list item entities -// associated to a given user. -func (b *Storage) FetchBlockListItems(username string) ([]model.BlockListItem, error) { - var blItems []model.BlockListItem - if err := b.fetchAll(&blItems, []byte("blockListItems:"+username)); err != nil { - return nil, err - } - return blItems, nil -} - -func (b *Storage) blockListItemKey(username, jid string) []byte { - return []byte("blockListItems:" + username + ":" + jid) -} diff --git a/storage/badgerdb/block_list_test.go b/storage/badgerdb/block_list_test.go deleted file mode 100644 index c6b35fe4c..000000000 --- a/storage/badgerdb/block_list_test.go +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "sort" - "testing" - - "github.com/ortuman/jackal/model" - "github.com/stretchr/testify/require" -) - -func TestBadgerDB_BlockListItems(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - items := []model.BlockListItem{ - {Username: "ortuman", JID: "juliet@jackal.im"}, - {Username: "ortuman", JID: "user@jackal.im"}, - {Username: "ortuman", JID: "romeo@jackal.im"}, - } - sort.Slice(items, func(i, j int) bool { return items[i].JID < items[j].JID }) - - err := h.db.InsertBlockListItems(items) - require.Nil(t, err) - - sItems, err := h.db.FetchBlockListItems("ortuman") - sort.Slice(sItems, func(i, j int) bool { return sItems[i].JID < sItems[j].JID }) - require.Nil(t, err) - require.Equal(t, items, sItems) - - items = append(items[:1], items[2:]...) - h.db.DeleteBlockListItems([]model.BlockListItem{{Username: "ortuman", JID: "romeo@jackal.im"}}) - - sItems, err = h.db.FetchBlockListItems("ortuman") - sort.Slice(items, func(i, j int) bool { return items[i].JID < items[j].JID }) - require.Nil(t, err) - require.Equal(t, items, sItems) - - err = h.db.DeleteBlockListItems(items) - require.Nil(t, err) - sItems, _ = h.db.FetchBlockListItems("ortuman") - require.Equal(t, 0, len(sItems)) -} diff --git a/storage/badgerdb/config.go b/storage/badgerdb/config.go deleted file mode 100644 index f534c7f01..000000000 --- a/storage/badgerdb/config.go +++ /dev/null @@ -1,24 +0,0 @@ -package badgerdb - -// Config represents BadgerDB storage configuration. -type Config struct { - DataDir string `yaml:"data_dir"` -} - -// DefaultDataDir is the default directory for BadgerDB storage -const DefaultDataDir = "./data" - -// UnmarshalYAML satisfies Unmarshaler interface -func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { - type rawConfig Config - - parsed := rawConfig{DataDir: DefaultDataDir} - - if err := unmarshal(&parsed); err != nil { - return err - } - - *c = Config(parsed) - - return nil -} diff --git a/storage/badgerdb/offline.go b/storage/badgerdb/offline.go deleted file mode 100644 index 5c129fd7e..000000000 --- a/storage/badgerdb/offline.go +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/xmpp" -) - -// InsertOfflineMessage inserts a new message element into -// user's offline queue. -func (b *Storage) InsertOfflineMessage(message *xmpp.Message, username string) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(message, b.offlineMessageKey(username, message.ID()), tx) - }) -} - -// CountOfflineMessages returns current length of user's offline queue. -func (b *Storage) CountOfflineMessages(username string) (int, error) { - cnt := 0 - prefix := []byte("offlineMessages:" + username) - err := b.forEachKey(prefix, func(key []byte) error { - cnt++ - return nil - }) - return cnt, err -} - -// FetchOfflineMessages retrieves from storage current user offline queue. -func (b *Storage) FetchOfflineMessages(username string) ([]xmpp.Message, error) { - var msgs []xmpp.Message - if err := b.fetchAll(&msgs, []byte("offlineMessages:"+username)); err != nil { - return nil, err - } - switch len(msgs) { - case 0: - return nil, nil - default: - ret := make([]xmpp.Message, len(msgs)) - for i := 0; i < len(msgs); i++ { - ret[i] = msgs[i] - } - return ret, nil - } -} - -// DeleteOfflineMessages clears a user offline queue. -func (b *Storage) DeleteOfflineMessages(username string) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.deletePrefix([]byte("offlineMessages:"+username), tx) - }) -} - -func (b *Storage) offlineMessageKey(username, identifier string) []byte { - return []byte("offlineMessages:" + username + ":" + identifier) -} diff --git a/storage/badgerdb/offline_test.go b/storage/badgerdb/offline_test.go deleted file mode 100644 index 1b1c7ce0b..000000000 --- a/storage/badgerdb/offline_test.go +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "testing" - - "github.com/ortuman/jackal/xmpp" - "github.com/pborman/uuid" - "github.com/stretchr/testify/require" -) - -func TestBadgerDB_OfflineMessages(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - msg1 := xmpp.NewMessageType(uuid.New(), xmpp.NormalType) - b1 := xmpp.NewElementName("body") - b1.SetText("Hi buddy!") - msg1.AppendElement(b1) - - msg2 := xmpp.NewMessageType(uuid.New(), xmpp.NormalType) - b2 := xmpp.NewElementName("body") - b2.SetText("what's up?!") - msg1.AppendElement(b1) - - require.NoError(t, h.db.InsertOfflineMessage(msg1, "ortuman")) - require.NoError(t, h.db.InsertOfflineMessage(msg2, "ortuman")) - - cnt, err := h.db.CountOfflineMessages("ortuman") - require.Nil(t, err) - require.Equal(t, 2, cnt) - - msgs, err := h.db.FetchOfflineMessages("ortuman") - require.Nil(t, err) - require.Equal(t, 2, len(msgs)) - - msgs2, err := h.db.FetchOfflineMessages("ortuman2") - require.Nil(t, err) - require.Equal(t, 0, len(msgs2)) - - require.NoError(t, h.db.DeleteOfflineMessages("ortuman")) - cnt, err = h.db.CountOfflineMessages("ortuman") - require.Nil(t, err) - require.Equal(t, 0, cnt) -} diff --git a/storage/badgerdb/private.go b/storage/badgerdb/private.go deleted file mode 100644 index 825dba905..000000000 --- a/storage/badgerdb/private.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/xmpp" -) - -// InsertOrUpdatePrivateXML inserts a new private element into storage, -// or updates it in case it's been previously inserted. -func (b *Storage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error { - r := xmpp.NewElementName("r") - r.AppendElements(privateXML) - return b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(r, b.privateStorageKey(username, namespace), tx) - }) -} - -// FetchPrivateXML retrieves from storage a private element. -func (b *Storage) FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) { - var r xmpp.Element - err := b.fetch(&r, b.privateStorageKey(username, namespace)) - switch err { - case nil: - return r.Elements().All(), nil - case errBadgerDBEntityNotFound: - return nil, nil - default: - return nil, err - } -} - -func (b *Storage) privateStorageKey(username, namespace string) []byte { - return []byte("privateElements:" + username + ":" + namespace) -} diff --git a/storage/badgerdb/private_test.go b/storage/badgerdb/private_test.go deleted file mode 100644 index 3b13e2c9b..000000000 --- a/storage/badgerdb/private_test.go +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "testing" - - "github.com/ortuman/jackal/xmpp" - "github.com/stretchr/testify/require" -) - -func TestBadgerDB_PrivateXML(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - pv1 := xmpp.NewElementNamespace("ex1", "exodus:ns") - pv2 := xmpp.NewElementNamespace("ex2", "exodus:ns") - - require.NoError(t, h.db.InsertOrUpdatePrivateXML([]xmpp.XElement{pv1, pv2}, "exodus:ns", "ortuman")) - - prvs, err := h.db.FetchPrivateXML("exodus:ns", "ortuman") - require.Nil(t, err) - require.Equal(t, 2, len(prvs)) - - prvs2, err := h.db.FetchPrivateXML("exodus:ns", "ortuman2") - require.Nil(t, prvs2) - require.Nil(t, err) -} diff --git a/storage/badgerdb/roster.go b/storage/badgerdb/roster.go deleted file mode 100644 index 717687a26..000000000 --- a/storage/badgerdb/roster.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/model/rostermodel" -) - -// InsertOrUpdateRosterItem inserts a new roster item entity into storage, -// or updates it in case it's been previously inserted. -func (b *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) { - if err := b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(ri, b.rosterItemKey(ri.Username, ri.JID), tx) - }); err != nil { - return rostermodel.Version{}, err - } - return b.updateRosterVer(ri.Username, false) -} - -// DeleteRosterItem deletes a roster item entity from storage. -func (b *Storage) DeleteRosterItem(user, contact string) (rostermodel.Version, error) { - if err := b.db.Update(func(tx *badger.Txn) error { - return b.delete(b.rosterItemKey(user, contact), tx) - }); err != nil { - return rostermodel.Version{}, err - } - return b.updateRosterVer(user, true) -} - -// FetchRosterItems retrieves from storage all roster item entities -// associated to a given user. -func (b *Storage) FetchRosterItems(user string) ([]rostermodel.Item, rostermodel.Version, error) { - var ris []rostermodel.Item - if err := b.fetchAll(&ris, []byte("rosterItems:"+user)); err != nil { - return nil, rostermodel.Version{}, err - } - ver, err := b.fetchRosterVer(user) - return ris, ver, err -} - -// FetchRosterItemsInGroups retrieves from storage all roster item entities -// associated to a given user and a set of groups. -func (b *Storage) FetchRosterItemsInGroups(user string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { - groupSet := make(map[string]struct{}, len(groups)) - for _, group := range groups { - groupSet[group] = struct{}{} - } - // fetch all items - var ris []rostermodel.Item - if err := b.fetchAll(&ris, []byte("rosterItems:"+user)); err != nil { - return nil, rostermodel.Version{}, err - } - var res []rostermodel.Item - for _, ri := range ris { - for _, riGroup := range ri.Groups { - if _, ok := groupSet[riGroup]; ok { - res = append(res, ri) - break - } - } - } - ver, err := b.fetchRosterVer(user) - return res, ver, err -} - -// FetchRosterItem retrieves from storage a roster item entity. -func (b *Storage) FetchRosterItem(user, contact string) (*rostermodel.Item, error) { - var ri rostermodel.Item - err := b.fetch(&ri, b.rosterItemKey(user, contact)) - switch err { - case nil: - return &ri, nil - case errBadgerDBEntityNotFound: - return nil, nil - default: - return nil, err - } -} - -// InsertOrUpdateRosterNotification inserts a new roster notification entity -// into storage, or updates it in case it's been previously inserted. -func (b *Storage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(rn, b.rosterNotificationKey(rn.Contact, rn.JID), tx) - }) -} - -// DeleteRosterNotification deletes a roster notification entity from storage. -func (b *Storage) DeleteRosterNotification(contact, jid string) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.delete(b.rosterNotificationKey(contact, jid), tx) - }) -} - -// FetchRosterNotification retrieves from storage a roster notification entity. -func (b *Storage) FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) { - var rn rostermodel.Notification - err := b.fetch(&rn, b.rosterNotificationKey(contact, jid)) - switch err { - case nil: - return &rn, nil - case errBadgerDBEntityNotFound: - return nil, nil - default: - return nil, err - } -} - -// FetchRosterNotifications retrieves from storage all roster notifications -// associated to a given user. -func (b *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { - var rns []rostermodel.Notification - if err := b.fetchAll(&rns, []byte("rosterNotifications:"+contact)); err != nil { - return nil, err - } - return rns, nil -} - -func (b *Storage) updateRosterVer(username string, isDeletion bool) (rostermodel.Version, error) { - v, err := b.fetchRosterVer(username) - if err != nil { - return rostermodel.Version{}, err - } - v.Ver++ - if isDeletion { - v.DeletionVer = v.Ver - } - if err := b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(&v, b.rosterVersionKey(username), tx) - }); err != nil { - return rostermodel.Version{}, err - } - return v, nil -} - -func (b *Storage) fetchRosterVer(username string) (rostermodel.Version, error) { - var ver rostermodel.Version - err := b.fetch(&ver, b.rosterVersionKey(username)) - switch err { - case nil, errBadgerDBEntityNotFound: - return ver, nil - default: - return ver, err - } -} - -func (b *Storage) rosterItemKey(user, contact string) []byte { - return []byte("rosterItems:" + user + ":" + contact) -} - -func (b *Storage) rosterVersionKey(username string) []byte { - return []byte("rosterVersions:" + username) -} - -func (b *Storage) rosterNotificationKey(contact, jid string) []byte { - return []byte("rosterNotifications:" + contact + ":" + jid) -} diff --git a/storage/badgerdb/roster_test.go b/storage/badgerdb/roster_test.go deleted file mode 100644 index 1c93aa8a6..000000000 --- a/storage/badgerdb/roster_test.go +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "testing" - - "github.com/ortuman/jackal/model/rostermodel" - "github.com/ortuman/jackal/xmpp" - "github.com/stretchr/testify/require" -) - -func TestBadgerDB_RosterItems(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - ri1 := &rostermodel.Item{ - Username: "ortuman", - JID: "juliet@jackal.im", - Subscription: "both", - Groups: []string{"general", "friends"}, - } - ri2 := &rostermodel.Item{ - Username: "ortuman", - JID: "romeo@jackal.im", - Subscription: "both", - Groups: []string{"general", "buddies"}, - } - ri3 := &rostermodel.Item{ - Username: "ortuman", - JID: "hamlet@jackal.im", - Subscription: "both", - Groups: []string{"family", "friends"}, - } - _, err := h.db.InsertOrUpdateRosterItem(ri1) - require.Nil(t, err) - _, err = h.db.InsertOrUpdateRosterItem(ri2) - require.Nil(t, err) - _, err = h.db.InsertOrUpdateRosterItem(ri3) - require.Nil(t, err) - - ris, _, err := h.db.FetchRosterItems("ortuman") - require.Nil(t, err) - require.Equal(t, 3, len(ris)) - - ris, _, err = h.db.FetchRosterItemsInGroups("ortuman", []string{"friends"}) - require.Nil(t, err) - require.Equal(t, 2, len(ris)) - - ris, _, err = h.db.FetchRosterItemsInGroups("ortuman", []string{"general"}) - require.Nil(t, err) - require.Equal(t, 2, len(ris)) - - ris, _, err = h.db.FetchRosterItemsInGroups("ortuman", []string{"buddies"}) - require.Nil(t, err) - require.Equal(t, 1, len(ris)) - - ris2, _, err := h.db.FetchRosterItems("ortuman2") - require.Nil(t, err) - require.Equal(t, 0, len(ris2)) - - ri4, err := h.db.FetchRosterItem("ortuman", "juliet@jackal.im") - require.Nil(t, err) - require.Equal(t, ri1, ri4) - - _, err = h.db.DeleteRosterItem("ortuman", "juliet@jackal.im") - require.NoError(t, err) - _, err = h.db.DeleteRosterItem("ortuman", "romeo@jackal.im") - require.NoError(t, err) - _, err = h.db.DeleteRosterItem("ortuman", "hamlet@jackal.im") - require.NoError(t, err) - - ris, _, err = h.db.FetchRosterItems("ortuman") - require.Nil(t, err) - require.Equal(t, 0, len(ris)) -} - -func TestBadgerDB_RosterNotifications(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - rn1 := rostermodel.Notification{ - Contact: "ortuman", - JID: "juliet@jackal.im", - Presence: &xmpp.Presence{}, - } - rn2 := rostermodel.Notification{ - Contact: "ortuman", - JID: "romeo@jackal.im", - Presence: &xmpp.Presence{}, - } - require.NoError(t, h.db.InsertOrUpdateRosterNotification(&rn1)) - require.NoError(t, h.db.InsertOrUpdateRosterNotification(&rn2)) - - rns, err := h.db.FetchRosterNotifications("ortuman") - require.Nil(t, err) - require.Equal(t, 2, len(rns)) - - rns2, err := h.db.FetchRosterNotifications("ortuman2") - require.Nil(t, err) - require.Equal(t, 0, len(rns2)) - - require.NoError(t, h.db.DeleteRosterNotification(rn1.Contact, rn1.JID)) - - rns, err = h.db.FetchRosterNotifications("ortuman") - require.Nil(t, err) - require.Equal(t, 1, len(rns)) - - require.NoError(t, h.db.DeleteRosterNotification(rn2.Contact, rn2.JID)) - - rns, err = h.db.FetchRosterNotifications("ortuman") - require.Nil(t, err) - require.Equal(t, 0, len(rns)) -} diff --git a/storage/badgerdb/user.go b/storage/badgerdb/user.go deleted file mode 100644 index e76892a86..000000000 --- a/storage/badgerdb/user.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/model" -) - -// InsertOrUpdateUser inserts a new user entity into storage, -// or updates it in case it's been previously inserted. -func (b *Storage) InsertOrUpdateUser(user *model.User) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(user, b.userKey(user.Username), tx) - }) -} - -// DeleteUser deletes a user entity from storage. -func (b *Storage) DeleteUser(username string) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.delete(b.userKey(username), tx) - }) -} - -// FetchUser retrieves from storage a user entity. -func (b *Storage) FetchUser(username string) (*model.User, error) { - var usr model.User - err := b.fetch(&usr, b.userKey(username)) - switch err { - case nil: - return &usr, nil - case errBadgerDBEntityNotFound: - return nil, nil - default: - return nil, err - } -} - -// UserExists returns whether or not a user exists within storage. -func (b *Storage) UserExists(username string) (bool, error) { - err := b.fetch(nil, b.userKey(username)) - switch err { - case nil: - return true, nil - case errBadgerDBEntityNotFound: - return false, nil - default: - return false, err - } -} - -func (b *Storage) userKey(username string) []byte { - return []byte("users:" + username) -} diff --git a/storage/badgerdb/user_test.go b/storage/badgerdb/user_test.go deleted file mode 100644 index 98e02bd09..000000000 --- a/storage/badgerdb/user_test.go +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "testing" - - "github.com/ortuman/jackal/model" - "github.com/stretchr/testify/require" -) - -func TestBadgerDB_User(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - usr := model.User{Username: "ortuman", Password: "1234"} - - err := h.db.InsertOrUpdateUser(&usr) - require.Nil(t, err) - - usr2, err := h.db.FetchUser("ortuman") - require.Nil(t, err) - require.Equal(t, "ortuman", usr2.Username) - require.Equal(t, "1234", usr2.Password) - - exists, err := h.db.UserExists("ortuman") - require.Nil(t, err) - require.True(t, exists) - - usr3, err := h.db.FetchUser("ortuman2") - require.Nil(t, usr3) - require.Nil(t, err) - - err = h.db.DeleteUser("ortuman") - require.Nil(t, err) - - exists, err = h.db.UserExists("ortuman") - require.Nil(t, err) - require.False(t, exists) -} diff --git a/storage/badgerdb/vcard.go b/storage/badgerdb/vcard.go deleted file mode 100644 index 3452f95ab..000000000 --- a/storage/badgerdb/vcard.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "github.com/dgraph-io/badger" - "github.com/ortuman/jackal/xmpp" -) - -// InsertOrUpdateVCard inserts a new vCard element into storage, -// or updates it in case it's been previously inserted. -func (b *Storage) InsertOrUpdateVCard(vCard xmpp.XElement, username string) error { - return b.db.Update(func(tx *badger.Txn) error { - return b.insertOrUpdate(vCard, b.vCardKey(username), tx) - }) -} - -// FetchVCard retrieves from storage a vCard element associated -// to a given user. -func (b *Storage) FetchVCard(username string) (xmpp.XElement, error) { - var vCard xmpp.Element - err := b.fetch(&vCard, b.vCardKey(username)) - switch err { - case nil: - return &vCard, nil - case errBadgerDBEntityNotFound: - return nil, nil - default: - return nil, err - } -} - -func (b *Storage) vCardKey(username string) []byte { - return []byte("vCards:" + username) -} diff --git a/storage/badgerdb/vcard_test.go b/storage/badgerdb/vcard_test.go deleted file mode 100644 index 4a6dd4764..000000000 --- a/storage/badgerdb/vcard_test.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package badgerdb - -import ( - "testing" - - "github.com/ortuman/jackal/xmpp" - "github.com/stretchr/testify/require" -) - -func TestBadgerDB_VCard(t *testing.T) { - t.Parallel() - - h := tUtilBadgerDBSetup() - defer tUtilBadgerDBTeardown(h) - - vcard := xmpp.NewElementNamespace("vCard", "vcard-temp") - fn := xmpp.NewElementName("FN") - fn.SetText("Miguel Ɓngel OrtuƱo") - vcard.AppendElement(fn) - - err := h.db.InsertOrUpdateVCard(vcard, "ortuman") - require.Nil(t, err) - - vcard2, err := h.db.FetchVCard("ortuman") - require.Nil(t, err) - require.Equal(t, "vCard", vcard2.Name()) - require.Equal(t, "vcard-temp", vcard2.Namespace()) - require.NotNil(t, vcard2.Elements().Child("FN")) - - vcard3, err := h.db.FetchVCard("ortuman2") - require.Nil(t, vcard3) - require.Nil(t, err) -} diff --git a/storage/block_list.go b/storage/block_list.go deleted file mode 100644 index e6a4db145..000000000 --- a/storage/block_list.go +++ /dev/null @@ -1,27 +0,0 @@ -package storage - -import "github.com/ortuman/jackal/model" - -// blockListStorage defines storage operations for user's block list -type blockListStorage interface { - InsertBlockListItems(items []model.BlockListItem) error - DeleteBlockListItems(items []model.BlockListItem) error - FetchBlockListItems(username string) ([]model.BlockListItem, error) -} - -// InsertBlockListItems inserts a set of block list item entities -// into storage, only in case they haven't been previously inserted. -func InsertBlockListItems(items []model.BlockListItem) error { - return instance().InsertBlockListItems(items) -} - -// DeleteBlockListItems deletes a set of block list item entities from storage. -func DeleteBlockListItems(items []model.BlockListItem) error { - return instance().DeleteBlockListItems(items) -} - -// FetchBlockListItems retrieves from storage all block list item entities -// associated to a given user. -func FetchBlockListItems(username string) ([]model.BlockListItem, error) { - return instance().FetchBlockListItems(username) -} diff --git a/storage/config.go b/storage/config.go index 3bcd53a7e..5e4db4e75 100644 --- a/storage/config.go +++ b/storage/config.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" - "github.com/ortuman/jackal/storage/badgerdb" "github.com/ortuman/jackal/storage/mysql" "github.com/ortuman/jackal/storage/pgsql" ) @@ -24,9 +23,6 @@ const ( // PostgreSQL represents a PostgreSQL storage type. PostgreSQL - // BadgerDB represents a BadgerDB storage type. - BadgerDB - // Memory represents a in-memstorage storage type. Memory ) @@ -34,7 +30,6 @@ const ( var typeStringMap = map[Type]string{ MySQL: "MySQL", PostgreSQL: "PostgreSQL", - BadgerDB: "BadgerDB", Memory: "Memory", } @@ -45,14 +40,12 @@ type Config struct { Type Type MySQL *mysql.Config PostgreSQL *pgsql.Config - BadgerDB *badgerdb.Config } type storageProxyType struct { - Type string `yaml:"type"` - MySQL *mysql.Config `yaml:"mysql"` - PostgreSQL *pgsql.Config `yaml:"pgsql"` - BadgerDB *badgerdb.Config `yaml:"badgerdb"` + Type string `yaml:"type"` + MySQL *mysql.Config `yaml:"mysql"` + PostgreSQL *pgsql.Config `yaml:"pgsql"` } // UnmarshalYAML satisfies Unmarshaler interface. @@ -78,13 +71,6 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { c.Type = PostgreSQL c.PostgreSQL = p.PostgreSQL - case "badgerdb": - if p.BadgerDB == nil { - return errors.New("storage.Config: couldn't read BadgerDB configuration") - } - c.Type = BadgerDB - c.BadgerDB = p.BadgerDB - case "memory": c.Type = Memory diff --git a/storage/config_test.go b/storage/config_test.go index a533854de..64a20d7b1 100644 --- a/storage/config_test.go +++ b/storage/config_test.go @@ -8,7 +8,6 @@ package storage import ( "testing" - "github.com/ortuman/jackal/storage/badgerdb" "github.com/ortuman/jackal/storage/mysql" "github.com/stretchr/testify/require" yaml "gopkg.in/yaml.v2" @@ -67,17 +66,6 @@ func TestStorageConfig(t *testing.T) { ` err = yaml.Unmarshal([]byte(invalidCfg), &cfg) require.NotNil(t, err) - - // Test if BadgerDB config unmarshaller sets defaults - badgerCfg := ` - type: badgerdb - badgerdb: {} -` - - err = yaml.Unmarshal([]byte(badgerCfg), &cfg) - require.Nil(t, err) - require.NotNil(t, cfg.BadgerDB) - require.Equal(t, cfg.BadgerDB.DataDir, badgerdb.DefaultDataDir) } func TestStorageBadConfig(t *testing.T) { diff --git a/storage/disabled.go b/storage/disabled.go deleted file mode 100644 index ad22360bf..000000000 --- a/storage/disabled.go +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package storage - -import ( - "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/rostermodel" - "github.com/ortuman/jackal/xmpp" -) - -type disabledStorage struct{} - -func (*disabledStorage) InsertOrUpdateUser(user *model.User) error { return nil } -func (*disabledStorage) DeleteUser(username string) error { return nil } -func (*disabledStorage) FetchUser(username string) (*model.User, error) { return nil, nil } -func (*disabledStorage) UserExists(username string) (bool, error) { return false, nil } - -func (*disabledStorage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) { - return rostermodel.Version{}, nil -} - -func (*disabledStorage) DeleteRosterItem(username, jid string) (rostermodel.Version, error) { - return rostermodel.Version{}, nil -} - -func (*disabledStorage) FetchRosterItems(username string) ([]rostermodel.Item, rostermodel.Version, error) { - return nil, rostermodel.Version{}, nil -} - -func (*disabledStorage) FetchRosterItemsInGroups(username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { - return nil, rostermodel.Version{}, nil -} - -func (*disabledStorage) FetchRosterItem(username, jid string) (*rostermodel.Item, error) { - return nil, nil -} - -func (*disabledStorage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error { - return nil -} - -func (*disabledStorage) DeleteRosterNotification(contact, jid string) error { - return nil -} - -func (*disabledStorage) FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) { - return nil, nil -} - -func (*disabledStorage) FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { - return nil, nil -} - -func (*disabledStorage) InsertOfflineMessage(message *xmpp.Message, username string) error { - return nil -} - -func (*disabledStorage) CountOfflineMessages(username string) (int, error) { - return 0, nil -} - -func (*disabledStorage) FetchOfflineMessages(username string) ([]xmpp.Message, error) { - return nil, nil -} - -func (*disabledStorage) DeleteOfflineMessages(username string) error { - return nil -} - -func (*disabledStorage) InsertOrUpdateVCard(vCard xmpp.XElement, username string) error { - return nil -} - -func (*disabledStorage) FetchVCard(username string) (xmpp.XElement, error) { - return nil, nil -} - -func (*disabledStorage) FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) { - return nil, nil -} - -func (*disabledStorage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error { - return nil -} - -func (*disabledStorage) InsertBlockListItems(items []model.BlockListItem) error { - return nil -} - -func (*disabledStorage) DeleteBlockListItems(items []model.BlockListItem) error { - return nil -} - -func (*disabledStorage) FetchBlockListItems(username string) ([]model.BlockListItem, error) { - return nil, nil -} - -func (*disabledStorage) IsClusterCompatible() bool { - return false -} - -func (*disabledStorage) Close() error { - return nil -} diff --git a/storage/memory/block_list.go b/storage/memory/block_list.go new file mode 100644 index 000000000..6327fcba7 --- /dev/null +++ b/storage/memory/block_list.go @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + "github.com/ortuman/jackal/model" + "github.com/ortuman/jackal/model/serializer" +) + +// BlockList represents an in-memory block-list storage. +type BlockList struct { + *memoryStorage +} + +// NewBlockList returns an instance of BlockList in-memory storage. +func NewBlockList() *BlockList { + return &BlockList{memoryStorage: newStorage()} +} + +// InsertBlockListItem inserts a block list item entity into storage if not previously inserted. +func (m *BlockList) InsertBlockListItem(_ context.Context, item *model.BlockListItem) error { + return m.updateInWriteLock(blockListItemKey(item.Username), func(b []byte) ([]byte, error) { + var items []model.BlockListItem + if len(b) > 0 { + if err := serializer.DeserializeSlice(b, &items); err != nil { + return nil, err + } + } + for _, itm := range items { + if itm.JID == item.JID { + return b, nil // already inserted + } + } + items = append(items, *item) + + output, err := serializer.SerializeSlice(&items) + if err != nil { + return nil, err + } + return output, nil + }) +} + +// DeleteBlockListItem deletes a block list item entity from storage. +func (m *BlockList) DeleteBlockListItem(_ context.Context, item *model.BlockListItem) error { + return m.updateInWriteLock(blockListItemKey(item.Username), func(b []byte) ([]byte, error) { + var items []model.BlockListItem + if len(b) > 0 { + if err := serializer.DeserializeSlice(b, &items); err != nil { + return nil, err + } + } + for i, itm := range items { + if itm.JID == item.JID { + items = append(items[:i], items[i+1:]...) + + output, err := serializer.SerializeSlice(&items) + if err != nil { + return nil, err + } + return output, nil + } + } + return b, nil // not present + }) +} + +// FetchBlockListItems retrieves from storage all block list item entities associated to a given user. +func (m *BlockList) FetchBlockListItems(_ context.Context, username string) ([]model.BlockListItem, error) { + var items []model.BlockListItem + _, err := m.getEntities(blockListItemKey(username), &items) + if err != nil { + return nil, err + } + return items, nil +} + +func blockListItemKey(username string) string { + return "blockListItems:" + username +} diff --git a/storage/memory/block_list_test.go b/storage/memory/block_list_test.go new file mode 100644 index 000000000..f9821fa58 --- /dev/null +++ b/storage/memory/block_list_test.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/model" + "github.com/stretchr/testify/require" +) + +func TestMemoryStorage_InsertOrUpdateBlockListItems(t *testing.T) { + items := []model.BlockListItem{ + {Username: "ortuman", JID: "user@jackal.im"}, + {Username: "ortuman", JID: "romeo@jackal.im"}, + {Username: "ortuman", JID: "juliet@jackal.im"}, + } + s := NewBlockList() + EnableMockedError() + require.Equal(t, ErrMocked, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "user@jackal.im"})) + DisableMockedError() + + require.Nil(t, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "user@jackal.im"})) + require.Nil(t, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "romeo@jackal.im"})) + require.Nil(t, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "juliet@jackal.im"})) + + EnableMockedError() + _, err := s.FetchBlockListItems(context.Background(), "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() + + sItems, _ := s.FetchBlockListItems(context.Background(), "ortuman") + require.Equal(t, items, sItems) +} + +func TestMemoryStorage_DeleteBlockListItems(t *testing.T) { + s := NewBlockList() + require.Nil(t, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "user@jackal.im"})) + require.Nil(t, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "romeo@jackal.im"})) + require.Nil(t, s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "juliet@jackal.im"})) + + EnableMockedError() + require.Equal(t, ErrMocked, s.DeleteBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "romeo@jackal.im"})) + DisableMockedError() + + require.Nil(t, s.DeleteBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "romeo@jackal.im"})) + + sItems, _ := s.FetchBlockListItems(context.Background(), "ortuman") + require.Equal(t, []model.BlockListItem{ + {Username: "ortuman", JID: "user@jackal.im"}, + {Username: "ortuman", JID: "juliet@jackal.im"}, + }, sItems) +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go new file mode 100644 index 000000000..23b289d76 --- /dev/null +++ b/storage/memory/memory.go @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + "github.com/ortuman/jackal/storage/repository" +) + +type memoryContainer struct { + user *User + roster *Roster + presences *Presences + vCard *VCard + priv *Private + blockList *BlockList + pubSub *PubSub + offline *Offline + room *Room + occ *Occupant +} + +// New initializes in-memory storage and returns associated container. +func New() (repository.Container, error) { + var c memoryContainer + + c.user = NewUser() + c.roster = NewRoster() + c.presences = NewPresences() + c.vCard = NewVCard() + c.priv = NewPrivate() + c.blockList = NewBlockList() + c.pubSub = NewPubSub() + c.offline = NewOffline() + c.room = NewRoom() + c.occ = NewOccupant() + + return &c, nil +} + +func (c *memoryContainer) User() repository.User { return c.user } +func (c *memoryContainer) Roster() repository.Roster { return c.roster } +func (c *memoryContainer) Presences() repository.Presences { return c.presences } +func (c *memoryContainer) VCard() repository.VCard { return c.vCard } +func (c *memoryContainer) Private() repository.Private { return c.priv } +func (c *memoryContainer) BlockList() repository.BlockList { return c.blockList } +func (c *memoryContainer) PubSub() repository.PubSub { return c.pubSub } +func (c *memoryContainer) Offline() repository.Offline { return c.offline } + +func (c *memoryContainer) Close(_ context.Context) error { return nil } + +func (c *memoryContainer) IsClusterCompatible() bool { return false } +func (c *memoryContainer) Room() repository.Room { return c.room } +func (c *memoryContainer) Occupant() repository.Occupant { return c.occ } diff --git a/storage/memory/occupant.go b/storage/memory/occupant.go new file mode 100644 index 000000000..8a9415bfd --- /dev/null +++ b/storage/memory/occupant.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +// Room represents an in-memory room storage. +type Occupant struct { + *memoryStorage +} + +// NewOccupant returns an instance of Occupant in-memory storage. +func NewOccupant() *Occupant { + return &Occupant{memoryStorage: newStorage()} +} + +// UpsertOccupant inserts a new occupant entity into storage, or updates the existing occupant. +func (m *Occupant) UpsertOccupant(_ context.Context, occ *mucmodel.Occupant) error { + return m.saveEntity(occKey(occ.OccupantJID), occ) +} + +// DeleteOccupant deletes an occupant entity from storage. +func (m *Occupant) DeleteOccupant(_ context.Context, occJID *jid.JID) error { + return m.deleteKey(occKey(occJID)) +} + +// FetchOccupant retrieves from storage an occupant entity. +func (m *Occupant) FetchOccupant(_ context.Context, occJID *jid.JID) (*mucmodel.Occupant, error) { + var occ mucmodel.Occupant + ok, err := m.getEntity(occKey(occJID), &occ) + switch err { + case nil: + if ok { + return &occ, nil + } + return nil, nil + default: + return nil, err + } +} + +// OccupantExists returns whether or not an occupant exists within storage. +func (m *Occupant) OccupantExists(_ context.Context, occJID *jid.JID) (bool, error) { + return m.keyExists(occKey(occJID)) +} + +func occKey(occJID *jid.JID) string { + return "jid" + occJID.String() +} diff --git a/storage/memory/occupant_test.go b/storage/memory/occupant_test.go new file mode 100644 index 000000000..c603cc173 --- /dev/null +++ b/storage/memory/occupant_test.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "testing" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMemoryStorage_InsertOccupant(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + o, _ := mucmodel.NewOccupant(j, j.ToBareJID()) + o.AddResource("yard") + o.SetAffiliation("owner") + o.SetRole("moderator") + s := NewOccupant() + EnableMockedError() + err := s.UpsertOccupant(context.Background(), o) + require.Equal(t, ErrMocked, err) + DisableMockedError() + + err = s.UpsertOccupant(context.Background(), o) + require.Nil(t, err) +} + +func TestMemoryStorage_OccupantExists(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + s := NewOccupant() + EnableMockedError() + _, err := s.OccupantExists(context.Background(), j) + require.Equal(t, ErrMocked, err) + DisableMockedError() + + ok, err := s.OccupantExists(context.Background(), j) + require.Nil(t, err) + require.False(t, ok) + + o, _ := mucmodel.NewOccupant(j, j.ToBareJID()) + o.AddResource("yard") + o.SetAffiliation("owner") + o.SetRole("moderator") + s.saveEntity(occKey(j), o) + ok, err = s.OccupantExists(context.Background(), j) + require.Nil(t, err) + require.True(t, ok) +} + +func TestMemoryStorage_FetchOccupant(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + o, _ := mucmodel.NewOccupant(j, j.ToBareJID()) + o.AddResource("yard") + o.SetAffiliation("owner") + o.SetRole("moderator") + s := NewOccupant() + _ = s.UpsertOccupant(context.Background(), o) + + EnableMockedError() + _, err := s.FetchOccupant(context.Background(), j) + require.Equal(t, ErrMocked, err) + DisableMockedError() + + notInMemoryJID, _ := jid.NewWithString("romeo@jackal.im/yard", true) + occ, _ := s.FetchOccupant(context.Background(), notInMemoryJID) + require.Nil(t, occ) + + occ, _ = s.FetchOccupant(context.Background(), j) + require.NotNil(t, occ) + assert.EqualValues(t, o, occ) +} + +func TestMemoryStorage_DeleteOccupant(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + o, _ := mucmodel.NewOccupant(j, j.ToBareJID()) + o.AddResource("yard") + o.SetAffiliation("owner") + o.SetRole("moderator") + s := NewOccupant() + _ = s.UpsertOccupant(context.Background(), o) + + EnableMockedError() + require.Equal(t, ErrMocked, s.DeleteOccupant(context.Background(), j)) + DisableMockedError() + require.Nil(t, s.DeleteOccupant(context.Background(), j)) + + occ, _ := s.FetchOccupant(context.Background(), j) + require.Nil(t, occ) +} diff --git a/storage/memory/offline.go b/storage/memory/offline.go new file mode 100644 index 000000000..51af736b6 --- /dev/null +++ b/storage/memory/offline.go @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + "github.com/ortuman/jackal/model/serializer" + "github.com/ortuman/jackal/xmpp" +) + +// Offline represents an in-memory offline storage. +type Offline struct { + *memoryStorage +} + +// NewOffline returns an instance of Offline in-memory storage. +func NewOffline() *Offline { + return &Offline{memoryStorage: newStorage()} +} + +// InsertOfflineMessage inserts a new message element into user's offline queue. +func (m *Offline) InsertOfflineMessage(_ context.Context, message *xmpp.Message, username string) error { + return m.updateInWriteLock(offlineMessageKey(username), func(b []byte) ([]byte, error) { + var messages []xmpp.Message + if len(b) > 0 { + if err := serializer.DeserializeSlice(b, &messages); err != nil { + return nil, err + } + } + messages = append(messages, *message) + + b, err := serializer.SerializeSlice(&messages) + if err != nil { + return nil, err + } + return b, nil + }) +} + +// CountOfflineMessages returns current length of user's offline queue. +func (m *Offline) CountOfflineMessages(_ context.Context, username string) (int, error) { + var messages []xmpp.Message + _, err := m.getEntities(offlineMessageKey(username), &messages) + if err != nil { + return 0, err + } + return len(messages), nil +} + +// FetchOfflineMessages retrieves from storage current user offline queue. +func (m *Offline) FetchOfflineMessages(_ context.Context, username string) ([]xmpp.Message, error) { + var messages []xmpp.Message + _, err := m.getEntities(offlineMessageKey(username), &messages) + switch err { + case nil: + return messages, nil + default: + return nil, err + } +} + +// DeleteOfflineMessages clears a user offline queue. +func (m *Offline) DeleteOfflineMessages(_ context.Context, username string) error { + return m.deleteKey(offlineMessageKey(username)) +} + +func offlineMessageKey(username string) string { + return "offlineMessages:" + username +} diff --git a/storage/memstorage/offline_test.go b/storage/memory/offline_test.go similarity index 56% rename from storage/memstorage/offline_test.go rename to storage/memory/offline_test.go index dc5959d2e..b7eab7189 100644 --- a/storage/memstorage/offline_test.go +++ b/storage/memory/offline_test.go @@ -3,9 +3,10 @@ * See the LICENSE file for more information. */ -package memstorage +package memorystorage import ( + "context" "testing" "github.com/ortuman/jackal/xmpp" @@ -21,11 +22,12 @@ func TestMemoryStorage_InsertOfflineMessage(t *testing.T) { message.AppendElement(xmpp.NewElementName("body")) m, _ := xmpp.NewMessageFromElement(message, j, j) - s := New() - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.InsertOfflineMessage(m, "ortuman")) - s.DisableMockedError() - require.Nil(t, s.InsertOfflineMessage(m, "ortuman")) + s := NewOffline() + EnableMockedError() + require.Equal(t, ErrMocked, s.InsertOfflineMessage(context.Background(), m, "ortuman")) + DisableMockedError() + + require.Nil(t, s.InsertOfflineMessage(context.Background(), m, "ortuman")) } func TestMemoryStorage_CountOfflineMessages(t *testing.T) { @@ -35,14 +37,15 @@ func TestMemoryStorage_CountOfflineMessages(t *testing.T) { message.AppendElement(xmpp.NewElementName("body")) m, _ := xmpp.NewMessageFromElement(message, j, j) - s := New() - _ = s.InsertOfflineMessage(m, "ortuman") + s := NewOffline() + _ = s.InsertOfflineMessage(context.Background(), m, "ortuman") + + EnableMockedError() + _, err := s.CountOfflineMessages(context.Background(), "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() - s.EnableMockedError() - _, err := s.CountOfflineMessages("ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - cnt, _ := s.CountOfflineMessages("ortuman") + cnt, _ := s.CountOfflineMessages(context.Background(), "ortuman") require.Equal(t, 1, cnt) } @@ -53,14 +56,14 @@ func TestMemoryStorage_FetchOfflineMessages(t *testing.T) { message.AppendElement(xmpp.NewElementName("body")) m, _ := xmpp.NewMessageFromElement(message, j, j) - s := New() - _ = s.InsertOfflineMessage(m, "ortuman") + s := NewOffline() + _ = s.InsertOfflineMessage(context.Background(), m, "ortuman") - s.EnableMockedError() - _, err := s.FetchOfflineMessages("ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - elems, _ := s.FetchOfflineMessages("ortuman") + EnableMockedError() + _, err := s.FetchOfflineMessages(context.Background(), "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() + elems, _ := s.FetchOfflineMessages(context.Background(), "ortuman") require.Equal(t, 1, len(elems)) } @@ -71,14 +74,14 @@ func TestMemoryStorage_DeleteOfflineMessages(t *testing.T) { message.AppendElement(xmpp.NewElementName("body")) m, _ := xmpp.NewMessageFromElement(message, j, j) - s := New() - _ = s.InsertOfflineMessage(m, "ortuman") + s := NewOffline() + _ = s.InsertOfflineMessage(context.Background(), m, "ortuman") - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.DeleteOfflineMessages("ortuman")) - s.DisableMockedError() - require.Nil(t, s.DeleteOfflineMessages("ortuman")) + EnableMockedError() + require.Equal(t, ErrMocked, s.DeleteOfflineMessages(context.Background(), "ortuman")) + DisableMockedError() + require.Nil(t, s.DeleteOfflineMessages(context.Background(), "ortuman")) - elems, _ := s.FetchOfflineMessages("ortuman") + elems, _ := s.FetchOfflineMessages(context.Background(), "ortuman") require.Equal(t, 0, len(elems)) } diff --git a/storage/memory/presences.go b/storage/memory/presences.go new file mode 100644 index 000000000..84cfa061e --- /dev/null +++ b/storage/memory/presences.go @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "strings" + + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/model/serializer" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +type Presences struct { + *memoryStorage +} + +// NewPresences returns an instance of Presences in-memory storage. +func NewPresences() *Presences { + return &Presences{memoryStorage: newStorage()} +} + +// UpsertPresence inserts or updates a presence and links it to certain allocation. +func (m *Presences) UpsertPresence(_ context.Context, presence *xmpp.Presence, jid *jid.JID, _ string) (inserted bool, err error) { + var ok bool + k := presenceKey(jid) + if err := m.inWriteLock(func() error { + _, ok = m.b[k] + b, err := serializer.Serialize(presence) + if err != nil { + return err + } + m.b[k] = b + return nil + }); err != nil { + return false, err + } + return !ok, nil +} + +// FetchPresence retrieves from storage a concrete registered presence. +func (m *Presences) FetchPresence(_ context.Context, jid *jid.JID) (*capsmodel.PresenceCaps, error) { + var pCaps *capsmodel.PresenceCaps + + if err := m.inReadLock(func() error { + b := m.b[presenceKey(jid)] + if b == nil { + return nil + } + presenceCaps, err := m.deserializePresence(b) + if err != nil { + return err + } + pCaps = presenceCaps + return nil + }); err != nil { + return nil, err + } + return pCaps, nil +} + +// FetchPresencesMatchingJID retrives all storage presences matching a certain JID +func (m *Presences) FetchPresencesMatchingJID(ctx context.Context, j *jid.JID) ([]capsmodel.PresenceCaps, error) { + var usePrefix, useSuffix bool + var res []capsmodel.PresenceCaps + + if j.IsFullWithUser() { + pCaps, err := m.FetchPresence(ctx, j) + if err != nil { + return nil, err + } + if pCaps == nil { + return nil, nil + } + return []capsmodel.PresenceCaps{*pCaps}, nil + } + usePrefix = j.IsBare() + useSuffix = j.IsFullWithServer() + + if err := m.inReadLock(func() error { + for k, v := range m.b { + if !strings.HasPrefix(k, "presences:") { + continue + } + kJID, _ := jid.NewWithString(k[10:], true) + if usePrefix { + if !j.MatchesWithOptions(kJID, jid.MatchesBare) { + continue + } + } else if useSuffix { + if !j.MatchesWithOptions(kJID, jid.MatchesDomain|jid.MatchesResource) { + continue + } + } else if !j.MatchesWithOptions(kJID, jid.MatchesDomain) { + continue + } + presenceCaps, err := m.deserializePresence(v) + if err != nil { + return err + } + res = append(res, *presenceCaps) + } + return nil + }); err != nil { + return nil, err + } + return res, nil +} + +// DeletePresence removes from storage a concrete registered presence. +func (m *Presences) DeletePresence(_ context.Context, jid *jid.JID) error { + return m.deleteKey(presenceKey(jid)) +} + +func (m *Presences) DeleteAllocationPresences(ctx context.Context, _ string) error { + return m.ClearPresences(ctx) +} + +func (m *Presences) ClearPresences(_ context.Context) error { + return m.inWriteLock(func() error { + for k := range m.b { + if !strings.HasPrefix(k, "presences:") { + continue + } + delete(m.b, k) + } + return nil + }) +} + +func (m *Presences) UpsertCapabilities(_ context.Context, caps *capsmodel.Capabilities) error { + return m.saveEntity(capabilitiesKey(caps.Node, caps.Ver), caps) +} + +func (m *Presences) FetchCapabilities(_ context.Context, node, ver string) (*capsmodel.Capabilities, error) { + var caps capsmodel.Capabilities + + ok, err := m.getEntity(capabilitiesKey(node, ver), &caps) + switch err { + case nil: + if !ok { + return nil, nil + } + return &caps, nil + default: + return nil, err + } +} + +func (m *Presences) deserializePresence(b []byte) (*capsmodel.PresenceCaps, error) { + var pCaps capsmodel.PresenceCaps + var presence xmpp.Presence + + if err := serializer.Deserialize(b, &presence); err != nil { + return nil, err + } + pCaps.Presence = &presence + if c := presence.Capabilities(); c != nil { + if capsB := m.b[capabilitiesKey(c.Node, c.Ver)]; capsB != nil { + var caps capsmodel.Capabilities + if err := serializer.Deserialize(capsB, &caps); err != nil { + return nil, err + } + pCaps.Caps = &caps + } + } + return &pCaps, nil +} + +func presenceKey(jid *jid.JID) string { + return "presences:" + jid.String() +} + +func capabilitiesKey(node, ver string) string { + return "capabilities:" + node + ":" + ver +} diff --git a/storage/memory/presences_test.go b/storage/memory/presences_test.go new file mode 100644 index 000000000..958f13499 --- /dev/null +++ b/storage/memory/presences_test.go @@ -0,0 +1,97 @@ +package memorystorage + +import ( + "context" + "testing" + + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestMemoryStorage_FetchPresencesMatchingJID(t *testing.T) { + const allocID = "1234" + + j1, _ := jid.NewWithString("noelia@jackal.im/garden", true) + j2, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + j3, _ := jid.NewWithString("noelia@jackal.im/yard", true) + j4, _ := jid.NewWithString("boss@jabber.org/balcony", true) + + p1 := xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.AvailableType) + p2 := xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.AvailableType) + p3 := xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.AvailableType) + p4 := xmpp.NewPresence(j1, j1.ToBareJID(), xmpp.AvailableType) + + s := NewPresences() + ok, err := s.UpsertPresence(context.Background(), p1, j1, allocID) + require.True(t, ok) + require.Nil(t, err) + + ok, err = s.UpsertPresence(context.Background(), p2, j2, allocID) + require.True(t, ok) + require.Nil(t, err) + + ok, err = s.UpsertPresence(context.Background(), p3, j3, allocID) + require.True(t, ok) + require.Nil(t, err) + + ok, err = s.UpsertPresence(context.Background(), p4, j4, allocID) + require.True(t, ok) + require.Nil(t, err) + + // updating presence + ok, err = s.UpsertPresence(context.Background(), p1, j1, allocID) + require.False(t, ok) + require.Nil(t, err) + + mJID, _ := jid.NewWithString("jackal.im", true) + presences, _ := s.FetchPresencesMatchingJID(context.Background(), mJID) + require.Len(t, presences, 3) + + mJID, _ = jid.NewWithString("jackal.im/yard", true) + presences, _ = s.FetchPresencesMatchingJID(context.Background(), mJID) + require.Len(t, presences, 2) + + mJID, _ = jid.NewWithString("jabber.org", true) + presences, _ = s.FetchPresencesMatchingJID(context.Background(), mJID) + require.Len(t, presences, 1) + + _ = s.DeletePresence(context.Background(), j2) + mJID, _ = jid.NewWithString("jackal.im/yard", true) + presences, _ = s.FetchPresencesMatchingJID(context.Background(), mJID) + require.Len(t, presences, 1) + + _ = s.ClearPresences(context.Background()) + mJID, _ = jid.NewWithString("jackal.im", true) + presences, _ = s.FetchPresencesMatchingJID(context.Background(), mJID) + require.Len(t, presences, 0) +} + +func TestMemoryStorage_InsertCapabilities(t *testing.T) { + caps := capsmodel.Capabilities{Node: "n1", Ver: "1234A", Features: []string{"ns"}} + s := NewPresences() + EnableMockedError() + err := s.UpsertCapabilities(context.Background(), &caps) + require.Equal(t, ErrMocked, err) + DisableMockedError() + err = s.UpsertCapabilities(context.Background(), &caps) + require.Nil(t, err) +} + +func TestMemoryStorage_FetchCapabilities(t *testing.T) { + caps := capsmodel.Capabilities{Node: "n1", Ver: "1234A", Features: []string{"ns"}} + s := NewPresences() + _ = s.UpsertCapabilities(context.Background(), &caps) + + EnableMockedError() + _, err := s.FetchCapabilities(context.Background(), "n1", "1234A") + require.Equal(t, ErrMocked, err) + DisableMockedError() + + cs, _ := s.FetchCapabilities(context.Background(), "n1", "1234B") + require.Nil(t, cs) + + cs, _ = s.FetchCapabilities(context.Background(), "n1", "1234A") + require.NotNil(t, cs) +} diff --git a/storage/memory/private.go b/storage/memory/private.go new file mode 100644 index 000000000..2e2ceff95 --- /dev/null +++ b/storage/memory/private.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + "github.com/ortuman/jackal/xmpp" +) + +// Private represents an in-memory private storage. +type Private struct { + *memoryStorage +} + +// NewPrivate returns an instance of Private in-memory storage. +func NewPrivate() *Private { + return &Private{memoryStorage: newStorage()} +} + +// UpsertPrivateXML inserts a new private element into storage, or updates it in case it's been previously inserted. +func (m *Private) UpsertPrivateXML(_ context.Context, privateXML []xmpp.XElement, namespace string, username string) error { + var priv []xmpp.Element + + // convert to concrete type + for _, el := range privateXML { + priv = append(priv, *xmpp.NewElementFromElement(el)) + } + return m.saveEntities(privateStorageKey(username, namespace), &priv) +} + +// FetchPrivateXML retrieves from storage a private element. +func (m *Private) FetchPrivateXML(_ context.Context, namespace string, username string) ([]xmpp.XElement, error) { + var priv []xmpp.Element + _, err := m.getEntities(privateStorageKey(username, namespace), &priv) + if err != nil { + return nil, err + } + var ret []xmpp.XElement + for _, e := range priv { + ret = append(ret, &e) + } + return ret, nil +} + +func privateStorageKey(username, namespace string) string { + return "privateElements:" + username + ":" + namespace +} diff --git a/storage/memory/private_test.go b/storage/memory/private_test.go new file mode 100644 index 000000000..b66b27cd2 --- /dev/null +++ b/storage/memory/private_test.go @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/xmpp" + "github.com/stretchr/testify/require" +) + +func TestMemoryStorage_InsertPrivateXML(t *testing.T) { + private := xmpp.NewElementNamespace("exodus", "exodus:ns") + + s := NewPrivate() + EnableMockedError() + err := s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() + + err = s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") + require.Nil(t, err) +} + +func TestMemoryStorage_FetchPrivateXML(t *testing.T) { + private := xmpp.NewElementNamespace("exodus", "exodus:ns") + + s := NewPrivate() + _ = s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") + + EnableMockedError() + _, err := s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() + + elems, _ := s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") + require.Equal(t, 1, len(elems)) +} diff --git a/storage/memory/pubsub.go b/storage/memory/pubsub.go new file mode 100644 index 000000000..181ffa266 --- /dev/null +++ b/storage/memory/pubsub.go @@ -0,0 +1,519 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "strings" + + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + "github.com/ortuman/jackal/model/serializer" +) + +// PubSub represents an in-memory pubsub storage. +type PubSub struct { + *memoryStorage +} + +// NewPubSub returns an instance of PubSub in-memory storage. +func NewPubSub() *PubSub { + return &PubSub{memoryStorage: newStorage()} +} + +// FetchHosts returns all host identifiers. +func (m *PubSub) FetchHosts(_ context.Context) ([]string, error) { + var hosts []string + if err := m.inReadLock(func() error { + for k := range m.b { + if !strings.HasPrefix(k, "pubSubHostNodes:") { + continue + } + keySplits := strings.Split(k, ":") + if len(keySplits) != 2 { + continue + } + host := keySplits[1] + + var isPresent bool + for _, h := range hosts { + if h == host { + isPresent = true + break + } + } + if isPresent { + continue + } + hosts = append(hosts, host) + } + return nil + }); err != nil { + return nil, err + } + return hosts, nil +} + +// UpsertNode inserts a new pubsub node entity into storage, or updates it if previously inserted. +func (m *PubSub) UpsertNode(_ context.Context, node *pubsubmodel.Node) error { + b, err := serializer.Serialize(node) + if err != nil { + return err + } + return m.inWriteLock(func() error { + m.b[pubSubNodesKey(node.Host, node.Name)] = b + return m.upsertHostNode(node) + }) +} + +// FetchNodes retrieves from storage all node entities associated with a host. +func (m *PubSub) FetchNodes(_ context.Context, host string) ([]pubsubmodel.Node, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubHostNodesKey(host)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + var nodes []pubsubmodel.Node + + if err := serializer.DeserializeSlice(b, &nodes); err != nil { + return nil, err + } + return nodes, nil +} + +// FetchNode retrieves from storage a pubsub node entity. +func (m *PubSub) FetchNode(_ context.Context, host, name string) (*pubsubmodel.Node, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubNodesKey(host, name)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + var node pubsubmodel.Node + if err := serializer.Deserialize(b, &node); err != nil { + return nil, err + } + return &node, nil +} + +// FetchSubscribedNodes retrieves from storage all nodes to which a given jid is subscribed. +func (m *PubSub) FetchSubscribedNodes(_ context.Context, jid string) ([]pubsubmodel.Node, error) { + var nodes []pubsubmodel.Node + if err := m.inReadLock(func() error { + for k, b := range m.b { + if !strings.HasPrefix(k, "pubSubSubscriptions:") { + continue + } + keySplits := strings.Split(k, ":") + if len(keySplits) != 3 { + continue // wrong key format + } + host := keySplits[1] + name := keySplits[2] + + var subs []pubsubmodel.Subscription + if b != nil { + if err := serializer.DeserializeSlice(b, &subs); err != nil { + return err + } + } + for _, sub := range subs { + if sub.JID != jid || sub.Subscription != pubsubmodel.Subscribed { + continue + } + // fetch pubsub node + var node pubsubmodel.Node + + b := m.b[pubSubNodesKey(host, name)] + if b == nil { + continue + } + if err := serializer.Deserialize(b, &node); err != nil { + return err + } + nodes = append(nodes, node) + break + } + } + return nil + }); err != nil { + return nil, err + } + return nodes, nil +} + +// DeleteNode deletes a pubsub node from storage. +func (m *PubSub) DeleteNode(_ context.Context, host, name string) error { + return m.inWriteLock(func() error { + delete(m.b, pubSubNodesKey(host, name)) + delete(m.b, pubSubItemsKey(host, name)) + delete(m.b, pubSubAffiliationsKey(host, name)) + return m.deleteHostNode(host, name) + }) +} + +// UpsertNodeItem inserts a new pubsub node item entity into storage, or updates it if previously inserted. +func (m *PubSub) UpsertNodeItem(_ context.Context, item *pubsubmodel.Item, host, name string, maxNodeItems int) error { + return m.inWriteLock(func() error { + var b []byte + var items []pubsubmodel.Item + + b = m.b[pubSubItemsKey(host, name)] + if b != nil { + if err := serializer.DeserializeSlice(b, &items); err != nil { + return err + } + } + var updated bool + for i, itm := range items { + if itm.ID == item.ID { + items[i] = *item + updated = true + break + } + } + if !updated { + items = append(items, *item) + } + if len(items) > maxNodeItems { + items = items[len(items)-maxNodeItems:] // remove oldest elements + } + b, err := serializer.SerializeSlice(&items) + if err != nil { + return err + } + m.b[pubSubItemsKey(host, name)] = b + return nil + }) +} + +// FetchNodeItems retrieves all items associated to a node. +func (m *PubSub) FetchNodeItems(_ context.Context, host, name string) ([]pubsubmodel.Item, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubItemsKey(host, name)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + var items []pubsubmodel.Item + if err := serializer.DeserializeSlice(b, &items); err != nil { + return nil, err + } + return items, nil +} + +// FetchNodeItemsWithIDs retrieves all items matching any of the passed identifiers. +func (m *PubSub) FetchNodeItemsWithIDs(_ context.Context, host, name string, identifiers []string) ([]pubsubmodel.Item, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubItemsKey(host, name)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + identifiersSet := make(map[string]struct{}) + for _, id := range identifiers { + identifiersSet[id] = struct{}{} + } + var filteredItems, items []pubsubmodel.Item + if err := serializer.DeserializeSlice(b, &items); err != nil { + return nil, err + } + for _, itm := range items { + if _, ok := identifiersSet[itm.ID]; ok { + filteredItems = append(filteredItems, itm) + } + } + return filteredItems, nil +} + +// FetchNodeLastItem retrieves last published node item. +func (m *PubSub) FetchNodeLastItem(_ context.Context, host, name string) (*pubsubmodel.Item, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubItemsKey(host, name)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + var items []pubsubmodel.Item + if err := serializer.DeserializeSlice(b, &items); err != nil { + return nil, err + } + return &items[len(items)-1], nil +} + +// UpsertNodeAffiliation inserts a new pubsub node affiliation into storage, or updates it if previously inserted. +func (m *PubSub) UpsertNodeAffiliation(_ context.Context, affiliation *pubsubmodel.Affiliation, host, name string) error { + return m.inWriteLock(func() error { + var b []byte + var affiliations []pubsubmodel.Affiliation + + b = m.b[pubSubAffiliationsKey(host, name)] + if b != nil { + if err := serializer.DeserializeSlice(b, &affiliations); err != nil { + return err + } + } + var updated bool + for i, aff := range affiliations { + if aff.JID == affiliation.JID { + affiliations[i] = *affiliation + updated = true + break + } + } + if !updated { + affiliations = append(affiliations, *affiliation) + } + b, err := serializer.SerializeSlice(&affiliations) + if err != nil { + return err + } + m.b[pubSubAffiliationsKey(host, name)] = b + return nil + }) +} + +// FetchNodeAffiliation retrieves a concrete node affiliation from storage. +func (m *PubSub) FetchNodeAffiliation(ctx context.Context, host, name, jid string) (*pubsubmodel.Affiliation, error) { + affiliations, err := m.FetchNodeAffiliations(ctx, host, name) + if err != nil { + return nil, err + } + for _, aff := range affiliations { + if aff.JID == jid { + return &aff, nil + } + } + return nil, nil +} + +// FetchNodeAffiliations retrieves all affiliations associated to a node. +func (m *PubSub) FetchNodeAffiliations(_ context.Context, host, name string) ([]pubsubmodel.Affiliation, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubAffiliationsKey(host, name)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + var affiliations []pubsubmodel.Affiliation + if err := serializer.DeserializeSlice(b, &affiliations); err != nil { + return nil, err + } + return affiliations, nil +} + +// DeleteNodeAffiliation deletes a pubsub node affiliation from storage. +func (m *PubSub) DeleteNodeAffiliation(_ context.Context, jid, host, name string) error { + return m.inWriteLock(func() error { + var b []byte + var affiliations []pubsubmodel.Affiliation + + b = m.b[pubSubAffiliationsKey(host, name)] + if b != nil { + if err := serializer.DeserializeSlice(b, &affiliations); err != nil { + return err + } + } + var deleted bool + for i, aff := range affiliations { + if aff.JID == jid { + affiliations = append(affiliations[:i], affiliations[i+1:]...) + deleted = true + break + } + } + if !deleted { + return nil + } + b, err := serializer.SerializeSlice(&affiliations) + if err != nil { + return err + } + m.b[pubSubAffiliationsKey(host, name)] = b + return nil + }) +} + +// UpsertNodeSubscription inserts a new pubsub node subscription into storage, or updates it if previously inserted. +func (m *PubSub) UpsertNodeSubscription(_ context.Context, subscription *pubsubmodel.Subscription, host, name string) error { + return m.inWriteLock(func() error { + var b []byte + var subscriptions []pubsubmodel.Subscription + + b = m.b[pubSubSubscriptionsKey(host, name)] + if b != nil { + if err := serializer.DeserializeSlice(b, &subscriptions); err != nil { + return err + } + } + var updated bool + for i, sub := range subscriptions { + if sub.JID == subscription.JID { + subscriptions[i] = *subscription + updated = true + break + } + } + if !updated { + subscriptions = append(subscriptions, *subscription) + } + b, err := serializer.SerializeSlice(&subscriptions) + if err != nil { + return err + } + m.b[pubSubSubscriptionsKey(host, name)] = b + return nil + }) +} + +// FetchNodeSubscriptions retrieves all subscriptions associated to a node. +func (m *PubSub) FetchNodeSubscriptions(_ context.Context, host, name string) ([]pubsubmodel.Subscription, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[pubSubSubscriptionsKey(host, name)] + return nil + }); err != nil { + return nil, err + } + if b == nil { + return nil, nil + } + var subscriptions []pubsubmodel.Subscription + if err := serializer.DeserializeSlice(b, &subscriptions); err != nil { + return nil, err + } + return subscriptions, nil +} + +// DeleteNodeSubscription deletes a pubsub node subscription from storage. +func (m *PubSub) DeleteNodeSubscription(_ context.Context, jid, host, name string) error { + return m.inWriteLock(func() error { + var b []byte + var subscriptions []pubsubmodel.Subscription + + b = m.b[pubSubSubscriptionsKey(host, name)] + if b != nil { + if err := serializer.DeserializeSlice(b, &subscriptions); err != nil { + return err + } + } + var deleted bool + for i, sub := range subscriptions { + if sub.JID == jid { + subscriptions = append(subscriptions[:i], subscriptions[i+1:]...) + deleted = true + break + } + } + if !deleted { + return nil + } + b, err := serializer.SerializeSlice(&subscriptions) + if err != nil { + return err + } + m.b[pubSubSubscriptionsKey(host, name)] = b + return nil + }) +} + +func (m *PubSub) upsertHostNode(node *pubsubmodel.Node) error { + var nodes []pubsubmodel.Node + + b := m.b[pubSubHostNodesKey(node.Host)] + if b != nil { + if err := serializer.DeserializeSlice(b, &nodes); err != nil { + return err + } + } + var updated bool + + for i, n := range nodes { + if n.Name == node.Name { + nodes[i] = *node + updated = true + break + } + } + if !updated { + nodes = append(nodes, *node) + } + + b, err := serializer.SerializeSlice(&nodes) + if err != nil { + return err + } + m.b[pubSubHostNodesKey(node.Host)] = b + return nil +} + +func (m *PubSub) deleteHostNode(host, name string) error { + var nodes []pubsubmodel.Node + + b := m.b[pubSubHostNodesKey(host)] + if b != nil { + if err := serializer.DeserializeSlice(b, &nodes); err != nil { + return err + } + } + for i, n := range nodes { + if n.Name == name { + nodes = append(nodes[:i], nodes[i+1:]...) + break + } + } + + b, err := serializer.SerializeSlice(&nodes) + if err != nil { + return err + } + m.b[pubSubHostNodesKey(host)] = b + return nil +} + +func pubSubHostNodesKey(host string) string { + return "pubSubHostNodes:" + host +} + +func pubSubNodesKey(host, name string) string { + return "pubSubNodes:" + host + ":" + name +} + +func pubSubAffiliationsKey(host, name string) string { + return "pubSubAffiliations:" + host + ":" + name +} + +func pubSubSubscriptionsKey(host, name string) string { + return "pubSubSubscriptions:" + host + ":" + name +} + +func pubSubItemsKey(host, name string) string { + return "pubSubItems:" + host + ":" + name +} diff --git a/storage/memory/pubsub_test.go b/storage/memory/pubsub_test.go new file mode 100644 index 000000000..b078a49f5 --- /dev/null +++ b/storage/memory/pubsub_test.go @@ -0,0 +1,246 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "reflect" + "testing" + + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + "github.com/ortuman/jackal/xmpp" + "github.com/stretchr/testify/require" +) + +func TestStorage_PubSubNode(t *testing.T) { + s := NewPubSub() + node := &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + } + require.Nil(t, s.UpsertNode(context.Background(), node)) + + n, err := s.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, n) + + require.True(t, reflect.DeepEqual(n, node)) + + node2 := &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings_2", + } + node3 := &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings_3", + } + node4 := &pubsubmodel.Node{ + Host: "noelia@jackal.im", + Name: "princely_musings_1", + } + require.Nil(t, s.UpsertNode(context.Background(), node2)) + require.Nil(t, s.UpsertNode(context.Background(), node3)) + require.Nil(t, s.UpsertNode(context.Background(), node4)) + + nodes, err := s.FetchNodes(context.Background(), "ortuman@jackal.im") + require.Nil(t, err) + require.NotNil(t, nodes) + + require.Len(t, nodes, 3) + require.Equal(t, "princely_musings", nodes[0].Name) + require.Equal(t, "princely_musings_2", nodes[1].Name) + require.Equal(t, "princely_musings_3", nodes[2].Name) + + require.Nil(t, s.DeleteNode(context.Background(), "ortuman@jackal.im", "princely_musings_2")) + + nodes, err = s.FetchNodes(context.Background(), "ortuman@jackal.im") + require.Nil(t, err) + require.NotNil(t, nodes) + + require.Len(t, nodes, 2) + require.Equal(t, "princely_musings", nodes[0].Name) + require.Equal(t, "princely_musings_3", nodes[1].Name) + + // fetch hosts + hosts, err := s.FetchHosts(context.Background()) + require.Nil(t, err) + require.Len(t, hosts, 2) +} + +func TestStorage_PubSubNodeItem(t *testing.T) { + s := NewPubSub() + item1 := &pubsubmodel.Item{ + ID: "id1", + Publisher: "ortuman@jackal.im", + Payload: xmpp.NewElementName("a"), + } + item2 := &pubsubmodel.Item{ + ID: "id2", + Publisher: "noelia@jackal.im", + Payload: xmpp.NewElementName("b"), + } + item3 := &pubsubmodel.Item{ + ID: "id3", + Publisher: "noelia@jackal.im", + Payload: xmpp.NewElementName("c"), + } + require.Nil(t, s.UpsertNodeItem(context.Background(), item1, "ortuman@jackal.im", "princely_musings", 1)) + require.Nil(t, s.UpsertNodeItem(context.Background(), item2, "ortuman@jackal.im", "princely_musings", 1)) + + items, err := s.FetchNodeItems(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, items) + + require.Len(t, items, 1) + require.True(t, reflect.DeepEqual(&items[0], item2)) + + // update item + require.Nil(t, s.UpsertNodeItem(context.Background(), item3, "ortuman@jackal.im", "princely_musings", 2)) + + items, err = s.FetchNodeItems(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, items) + + require.Len(t, items, 2) + require.True(t, reflect.DeepEqual(&items[0], item2)) + require.True(t, reflect.DeepEqual(&items[1], item3)) + + items, err = s.FetchNodeItemsWithIDs(context.Background(), "ortuman@jackal.im", "princely_musings", []string{"id3"}) + require.Nil(t, err) + require.NotNil(t, items) + + require.Len(t, items, 1) + require.Equal(t, "id3", items[0].ID) +} + +func TestStorage_PubSubNodeAffiliation(t *testing.T) { + s := NewPubSub() + aff1 := &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: "publisher", + } + aff2 := &pubsubmodel.Affiliation{ + JID: "noelia@jackal.im", + Affiliation: "publisher", + } + require.Nil(t, s.UpsertNodeAffiliation(context.Background(), aff1, "ortuman@jackal.im", "princely_musings")) + require.Nil(t, s.UpsertNodeAffiliation(context.Background(), aff2, "ortuman@jackal.im", "princely_musings")) + + affiliations, err := s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, affiliations) + + require.Len(t, affiliations, 2) + + // update affiliation + aff2.Affiliation = "owner" + require.Nil(t, s.UpsertNodeAffiliation(context.Background(), aff2, "ortuman@jackal.im", "princely_musings")) + + affiliations, err = s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, affiliations) + + require.Len(t, affiliations, 2) + + var updated bool + for _, aff := range affiliations { + if aff.JID == "noelia@jackal.im" { + require.Equal(t, "owner", aff.Affiliation) + updated = true + break + } + } + if !updated { + require.Fail(t, "affiliation for 'noelia@jackal.im' not found") + } + + // delete affiliation + err = s.DeleteNodeAffiliation(context.Background(), "noelia@jackal.im", "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + + affiliations, err = s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, affiliations) + + require.Len(t, affiliations, 1) +} + +func TestStorage_PubSubNodeSubscription(t *testing.T) { + s := NewPubSub() + node := &pubsubmodel.Node{ + Host: "ortuman@jackal.im", + Name: "princely_musings", + } + _ = s.UpsertNode(context.Background(), node) + + node2 := &pubsubmodel.Node{ + Host: "noelia@jackal.im", + Name: "princely_musings", + } + _ = s.UpsertNode(context.Background(), node2) + + sub1 := &pubsubmodel.Subscription{ + SubID: "1234", + JID: "ortuman@jackal.im", + Subscription: "subscribed", + } + sub2 := &pubsubmodel.Subscription{ + SubID: "5678", + JID: "noelia@jackal.im", + Subscription: "unsubscribed", + } + sub3 := &pubsubmodel.Subscription{ + SubID: "9012", + JID: "ortuman@jackal.im", + Subscription: "subscribed", + } + require.Nil(t, s.UpsertNodeSubscription(context.Background(), sub1, "ortuman@jackal.im", "princely_musings")) + require.Nil(t, s.UpsertNodeSubscription(context.Background(), sub2, "ortuman@jackal.im", "princely_musings")) + require.Nil(t, s.UpsertNodeSubscription(context.Background(), sub3, "noelia@jackal.im", "princely_musings")) + + // fetch user subscribed nodes + nodes, err := s.FetchSubscribedNodes(context.Background(), "ortuman@jackal.im") + require.Nil(t, err) + require.Len(t, nodes, 2) + + subscriptions, err := s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, subscriptions) + + require.Len(t, subscriptions, 2) + + // update affiliation + sub2.Subscription = "subscribed" + require.Nil(t, s.UpsertNodeSubscription(context.Background(), sub2, "ortuman@jackal.im", "princely_musings")) + + subscriptions, err = s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, subscriptions) + + require.Len(t, subscriptions, 2) + + var updated bool + for _, sub := range subscriptions { + if sub.JID == "noelia@jackal.im" { + require.Equal(t, "subscribed", sub.Subscription) + updated = true + break + } + } + if !updated { + require.Fail(t, "subscription for 'noelia@jackal.im' not found") + } + + // delete subscription + err = s.DeleteNodeSubscription(context.Background(), "noelia@jackal.im", "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + + subscriptions, err = s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, err) + require.NotNil(t, subscriptions) + + require.Len(t, subscriptions, 1) +} diff --git a/storage/memory/room.go b/storage/memory/room.go new file mode 100644 index 000000000..21812e477 --- /dev/null +++ b/storage/memory/room.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +// Room represents an in-memory room storage. +type Room struct { + *memoryStorage +} + +// NewRoom returns an instance of Room in-memory storage. +func NewRoom() *Room { + return &Room{memoryStorage: newStorage()} +} + +// UpsertRoom inserts a new room entity into storage, or updates the existing room. +func (m *Room) UpsertRoom(_ context.Context, room *mucmodel.Room) error { + return m.saveEntity(roomKey(room.RoomJID), room) +} + +// DeleteRoom deletes a room entity from storage. +func (m *Room) DeleteRoom(_ context.Context, roomJID *jid.JID) error { + return m.deleteKey(roomKey(roomJID)) +} + +// FetchRoom retrieves from storage a room entity. +func (m *Room) FetchRoom(_ context.Context, roomJID *jid.JID) (*mucmodel.Room, error) { + var room mucmodel.Room + ok, err := m.getEntity(roomKey(roomJID), &room) + switch err { + case nil: + if ok { + return &room, nil + } + return nil, nil + default: + return nil, err + } +} + +// RoomExists returns whether or not a room exists within storage. +func (m *Room) RoomExists(_ context.Context, roomJID *jid.JID) (bool, error) { + return m.keyExists(roomKey(roomJID)) +} + +func roomKey(roomJID *jid.JID) string { + return "rooms:" + roomJID.String() +} diff --git a/storage/memory/room_test.go b/storage/memory/room_test.go new file mode 100644 index 000000000..e5c04d016 --- /dev/null +++ b/storage/memory/room_test.go @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "testing" + + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMemoryStorage_InsertRoom(t *testing.T) { + r := GetTestRoom() + s := NewRoom() + EnableMockedError() + err := s.UpsertRoom(context.Background(), r) + require.Equal(t, ErrMocked, err) + DisableMockedError() + + err = s.UpsertRoom(context.Background(), r) + require.Nil(t, err) +} + +func TestMemoryStorage_RoomExists(t *testing.T) { + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + s := NewRoom() + EnableMockedError() + _, err := s.RoomExists(context.Background(), j) + require.Equal(t, ErrMocked, err) + DisableMockedError() + + ok, err := s.RoomExists(context.Background(), j) + require.Nil(t, err) + require.False(t, ok) + + r := GetTestRoom() + require.Equal(t, r.RoomJID, j) + s.saveEntity(roomKey(r.RoomJID), r) + ok, err = s.RoomExists(context.Background(), j) + require.Nil(t, err) + require.True(t, ok) +} + +func TestMemoryStorage_FetchRoom(t *testing.T) { + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + r := GetTestRoom() + s := NewRoom() + _ = s.UpsertRoom(context.Background(), r) + + EnableMockedError() + _, err := s.FetchRoom(context.Background(), j) + require.Equal(t, ErrMocked, err) + DisableMockedError() + + notInMemoryJID, _ := jid.NewWithString("faketestroom@conference.jackal.im", true) + roomFromMemory, _ := s.FetchRoom(context.Background(), notInMemoryJID) + require.Nil(t, roomFromMemory) + + roomFromMemory, _ = s.FetchRoom(context.Background(), j) + require.NotNil(t, roomFromMemory) + assert.EqualValues(t, r, roomFromMemory) +} + +func TestMemoryStorage_DeleteRoom(t *testing.T) { + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + r := GetTestRoom() + s := NewRoom() + _ = s.UpsertRoom(context.Background(), r) + + EnableMockedError() + require.Equal(t, ErrMocked, s.DeleteRoom(context.Background(), j)) + DisableMockedError() + require.Nil(t, s.DeleteRoom(context.Background(), j)) + + room, _ := s.FetchRoom(context.Background(), j) + require.Nil(t, room) +} + +func GetTestRoom() *mucmodel.Room { + rc := mucmodel.RoomConfig{ + Public: true, + Persistent: true, + PwdProtected: false, + Open: true, + Moderated: false, + } + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + + r := &mucmodel.Room{ + Name: "testRoom", + RoomJID: j, + Desc: "Room for Testing", + Config: &rc, + Locked: false, + } + + oJID, _ := jid.NewWithString("testroom@conference.jackal.im/owner", true) + owner, _ := mucmodel.NewOccupant(oJID, oJID.ToBareJID()) + r.AddOccupant(owner) + r.InviteUser(oJID.ToBareJID()) + + return r +} diff --git a/storage/memstorage/roster.go b/storage/memory/roster.go similarity index 56% rename from storage/memstorage/roster.go rename to storage/memory/roster.go index 8da06bc79..ab077d531 100644 --- a/storage/memstorage/roster.go +++ b/storage/memory/roster.go @@ -3,16 +3,29 @@ * See the LICENSE file for more information. */ -package memstorage +package memorystorage import ( - "github.com/ortuman/jackal/model/rostermodel" + "bytes" + "context" + "encoding/gob" + + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/model/serializer" ) -// InsertOrUpdateRosterItem inserts a new roster item entity into storage, -// or updates it in case it's been previously inserted. -func (m *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) { +// Roster represents an in-memory roster storage. +type Roster struct { + *memoryStorage +} + +// NewRoster returns an instance of Roster in-memory storage. +func NewRoster() *Roster { + return &Roster{memoryStorage: newStorage()} +} + +// UpsertRosterItem inserts a new roster item entity into storage, or updates it if previously inserted. +func (m *Roster) UpsertRosterItem(_ context.Context, ri *rostermodel.Item) (rostermodel.Version, error) { var rv rostermodel.Version err := m.inWriteLock(func() error { ris, fnErr := m.fetchRosterItems(ri.Username) @@ -32,6 +45,9 @@ func (m *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve } done: + if fnErr := m.upsertRosterGroups(ri.Username, ris); fnErr != nil { + return fnErr + } rv, fnErr = m.fetchRosterVersion(ri.Username) if fnErr != nil { return fnErr @@ -47,7 +63,7 @@ func (m *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve } // DeleteRosterItem deletes a roster item entity from storage. -func (m *Storage) DeleteRosterItem(user, contact string) (rostermodel.Version, error) { +func (m *Roster) DeleteRosterItem(_ context.Context, user, contact string) (rostermodel.Version, error) { var rv rostermodel.Version if err := m.inWriteLock(func() error { ris, fnErr := m.fetchRosterItems(user) @@ -64,6 +80,9 @@ func (m *Storage) DeleteRosterItem(user, contact string) (rostermodel.Version, e } } done: + if fnErr := m.upsertRosterGroups(user, ris); fnErr != nil { + return fnErr + } rv, fnErr = m.fetchRosterVersion(user) if fnErr != nil { return fnErr @@ -78,7 +97,7 @@ func (m *Storage) DeleteRosterItem(user, contact string) (rostermodel.Version, e } // FetchRosterItems retrieves from storage all roster item entities associated to a given user. -func (m *Storage) FetchRosterItems(user string) ([]rostermodel.Item, rostermodel.Version, error) { +func (m *Roster) FetchRosterItems(_ context.Context, user string) ([]rostermodel.Item, rostermodel.Version, error) { var ris []rostermodel.Item var rv rostermodel.Version @@ -96,9 +115,8 @@ func (m *Storage) FetchRosterItems(user string) ([]rostermodel.Item, rostermodel return ris, rv, nil } -// FetchRosterItemsInGroups retrieves from storage all roster item entities -// associated to a given user and a set of groups. -func (m *Storage) FetchRosterItemsInGroups(username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { +// FetchRosterItemsInGroups retrieves from storage all roster item entities associated to a given user and a set of groups. +func (m *Roster) FetchRosterItemsInGroups(_ context.Context, username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { var ris []rostermodel.Item var rv rostermodel.Version @@ -128,7 +146,7 @@ func (m *Storage) FetchRosterItemsInGroups(username string, groups []string) ([] } // FetchRosterItem retrieves from storage a roster item entity. -func (m *Storage) FetchRosterItem(user, contact string) (*rostermodel.Item, error) { +func (m *Roster) FetchRosterItem(_ context.Context, user, contact string) (*rostermodel.Item, error) { var ret *rostermodel.Item err := m.inReadLock(func() error { ris, fnErr := m.fetchRosterItems(user) @@ -146,9 +164,8 @@ func (m *Storage) FetchRosterItem(user, contact string) (*rostermodel.Item, erro return ret, err } -// InsertOrUpdateRosterNotification inserts a new roster notification entity -// into storage, or updates it in case it's been previously inserted. -func (m *Storage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error { +// UpsertRosterNotification inserts a new roster notification entity into storage, or updates it if previously inserted. +func (m *Roster) UpsertRosterNotification(_ context.Context, rn *rostermodel.Notification) error { return m.inWriteLock(func() error { rns, fnErr := m.fetchRosterNotifications(rn.Contact) if fnErr != nil { @@ -171,7 +188,7 @@ func (m *Storage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) } // DeleteRosterNotification deletes a roster notification entity from storage. -func (m *Storage) DeleteRosterNotification(contact, jid string) error { +func (m *Roster) DeleteRosterNotification(_ context.Context, contact, jid string) error { return m.inWriteLock(func() error { rns, fnErr := m.fetchRosterNotifications(contact) if fnErr != nil { @@ -188,7 +205,7 @@ func (m *Storage) DeleteRosterNotification(contact, jid string) error { } // FetchRosterNotification retrieves from storage a roster notification entity. -func (m *Storage) FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) { +func (m *Roster) FetchRosterNotification(_ context.Context, contact string, jid string) (*rostermodel.Notification, error) { var ret *rostermodel.Notification err := m.inReadLock(func() error { rns, fnErr := m.fetchRosterNotifications(contact) @@ -207,7 +224,7 @@ func (m *Storage) FetchRosterNotification(contact string, jid string) (*rostermo } // FetchRosterNotifications retrieves from storage all roster notifications associated to a given user. -func (m *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { +func (m *Roster) FetchRosterNotifications(_ context.Context, contact string) ([]rostermodel.Notification, error) { var rns []rostermodel.Notification if err := m.inReadLock(func() error { var fnErr error @@ -219,17 +236,30 @@ func (m *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notifi return rns, nil } -func (m *Storage) upsertRosterItems(ris []rostermodel.Item, user string) error { +// FetchRosterGroups retrieves all groups associated to a user roster. +func (m *Roster) FetchRosterGroups(_ context.Context, username string) ([]string, error) { + var groups []string + if err := m.inReadLock(func() error { + var fnErr error + groups, fnErr = m.fetchRosterGroups(username) + return fnErr + }); err != nil { + return nil, err + } + return groups, nil +} + +func (m *Roster) upsertRosterItems(ris []rostermodel.Item, user string) error { b, err := serializer.SerializeSlice(&ris) if err != nil { return err } - m.bytes[rosterItemsKey(user)] = b + m.b[rosterItemsKey(user)] = b return nil } -func (m *Storage) fetchRosterItems(user string) ([]rostermodel.Item, error) { - b := m.bytes[rosterItemsKey(user)] +func (m *Roster) fetchRosterItems(user string) ([]rostermodel.Item, error) { + b := m.b[rosterItemsKey(user)] if b == nil { return nil, nil } @@ -240,17 +270,17 @@ func (m *Storage) fetchRosterItems(user string) ([]rostermodel.Item, error) { return ris, nil } -func (m *Storage) upsertRosterVersion(rv rostermodel.Version, user string) error { +func (m *Roster) upsertRosterVersion(rv rostermodel.Version, user string) error { b, err := serializer.Serialize(&rv) if err != nil { return err } - m.bytes[rosterVersionKey(user)] = b + m.b[rosterVersionKey(user)] = b return nil } -func (m *Storage) fetchRosterVersion(user string) (rostermodel.Version, error) { - b := m.bytes[rosterVersionKey(user)] +func (m *Roster) fetchRosterVersion(user string) (rostermodel.Version, error) { + b := m.b[rosterVersionKey(user)] if b == nil { return rostermodel.Version{}, nil } @@ -261,17 +291,17 @@ func (m *Storage) fetchRosterVersion(user string) (rostermodel.Version, error) { return rv, nil } -func (m *Storage) upsertRosterNotifications(rns []rostermodel.Notification, contact string) error { +func (m *Roster) upsertRosterNotifications(rns []rostermodel.Notification, contact string) error { b, err := serializer.SerializeSlice(&rns) if err != nil { return err } - m.bytes[rosterNotificationsKey(contact)] = b + m.b[rosterNotificationsKey(contact)] = b return nil } -func (m *Storage) fetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { - b := m.bytes[rosterNotificationsKey(contact)] +func (m *Roster) fetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { + b := m.b[rosterNotificationsKey(contact)] if b == nil { return nil, nil } @@ -282,6 +312,57 @@ func (m *Storage) fetchRosterNotifications(contact string) ([]rostermodel.Notifi return rns, nil } +func (m *Roster) upsertRosterGroups(user string, ris []rostermodel.Item) error { + var groupsSet = make(map[string]struct{}) + // remove duplicates + for _, ri := range ris { + for _, group := range ri.Groups { + groupsSet[group] = struct{}{} + } + } + var groups []string + for group := range groupsSet { + groups = append(groups, group) + } + // encode groups + buf := bytes.NewBuffer(nil) + + enc := gob.NewEncoder(buf) + if err := enc.Encode(len(groups)); err != nil { + return err + } + for _, group := range groups { + if err := enc.Encode(group); err != nil { + return err + } + } + m.b[rosterGroupsKey(user)] = buf.Bytes() + return nil +} + +func (m *Roster) fetchRosterGroups(user string) ([]string, error) { + var ln int + var groups []string + + b := m.b[rosterGroupsKey(user)] + if b == nil { + return nil, nil + } + // decode groups + dec := gob.NewDecoder(bytes.NewReader(b)) + if err := dec.Decode(&ln); err != nil { + return nil, err + } + for i := 0; i < ln; i++ { + var group string + if err := dec.Decode(&group); err != nil { + return nil, err + } + groups = append(groups, group) + } + return groups, nil +} + func rosterItemsKey(user string) string { return "rosterItems:" + user } @@ -293,3 +374,7 @@ func rosterVersionKey(username string) string { func rosterNotificationsKey(contact string) string { return "rosterNotifications:" + contact } + +func rosterGroupsKey(username string) string { + return "rosterGroups:" + username +} diff --git a/storage/memstorage/roster_test.go b/storage/memory/roster_test.go similarity index 50% rename from storage/memstorage/roster_test.go rename to storage/memory/roster_test.go index c90669631..93d0a9180 100644 --- a/storage/memstorage/roster_test.go +++ b/storage/memory/roster_test.go @@ -3,12 +3,13 @@ * See the LICENSE file for more information. */ -package memstorage +package memorystorage import ( + "context" "testing" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/stretchr/testify/require" @@ -26,15 +27,15 @@ func TestMemoryStorage_InsertRosterItem(t *testing.T) { Groups: g, } - s := New() - s.EnableMockedError() - _, err := s.InsertOrUpdateRosterItem(&ri) - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - _, err = s.InsertOrUpdateRosterItem(&ri) + s := NewRoster() + EnableMockedError() + _, err := s.UpsertRosterItem(context.Background(), &ri) + require.Equal(t, ErrMocked, err) + DisableMockedError() + _, err = s.UpsertRosterItem(context.Background(), &ri) require.Nil(t, err) ri.Subscription = "to" - _, err = s.InsertOrUpdateRosterItem(&ri) + _, err = s.UpsertRosterItem(context.Background(), &ri) require.Nil(t, err) } @@ -49,18 +50,18 @@ func TestMemoryStorage_FetchRosterItem(t *testing.T) { Ver: 1, Groups: g, } - s := New() - _, _ = s.InsertOrUpdateRosterItem(&ri) + s := NewRoster() + _, _ = s.UpsertRosterItem(context.Background(), &ri) - s.EnableMockedError() - _, err := s.FetchRosterItem("user", "contact") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() + EnableMockedError() + _, err := s.FetchRosterItem(context.Background(), "user", "contact") + require.Equal(t, ErrMocked, err) + DisableMockedError() - ri3, _ := s.FetchRosterItem("user", "contact2") + ri3, _ := s.FetchRosterItem(context.Background(), "user", "contact2") require.Nil(t, ri3) - ri4, _ := s.FetchRosterItem("user", "contact") + ri4, _ := s.FetchRosterItem(context.Background(), "user", "contact") require.NotNil(t, ri4) require.Equal(t, "user", ri4.Username) require.Equal(t, "contact", ri4.JID) @@ -95,22 +96,30 @@ func TestMemoryStorage_FetchRosterItems(t *testing.T) { Groups: []string{"family", "friends"}, } - s := New() - _, _ = s.InsertOrUpdateRosterItem(&ri) - _, _ = s.InsertOrUpdateRosterItem(&ri2) - _, _ = s.InsertOrUpdateRosterItem(&ri3) + s := NewRoster() + _, _ = s.UpsertRosterItem(context.Background(), &ri) + _, _ = s.UpsertRosterItem(context.Background(), &ri2) + _, _ = s.UpsertRosterItem(context.Background(), &ri3) - s.EnableMockedError() - _, _, err := s.FetchRosterItems("user") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() + EnableMockedError() + _, _, err := s.FetchRosterItems(context.Background(), "user") + require.Equal(t, ErrMocked, err) + DisableMockedError() - ris, _, _ := s.FetchRosterItems("user") + ris, _, _ := s.FetchRosterItems(context.Background(), "user") require.Equal(t, 3, len(ris)) - ris, _, _ = s.FetchRosterItemsInGroups("user", []string{"friends"}) + ris, _, _ = s.FetchRosterItemsInGroups(context.Background(), "user", []string{"friends"}) require.Equal(t, 2, len(ris)) - ris, _, _ = s.FetchRosterItemsInGroups("user", []string{"buddies"}) + ris, _, _ = s.FetchRosterItemsInGroups(context.Background(), "user", []string{"buddies"}) require.Equal(t, 1, len(ris)) + + gr, _ := s.FetchRosterGroups(context.Background(), "user") + require.Len(t, gr, 4) + + require.Contains(t, gr, "general") + require.Contains(t, gr, "friends") + require.Contains(t, gr, "family") + require.Contains(t, gr, "buddies") } func TestMemoryStorage_DeleteRosterItem(t *testing.T) { @@ -124,21 +133,30 @@ func TestMemoryStorage_DeleteRosterItem(t *testing.T) { Ver: 1, Groups: g, } - s := New() - _, _ = s.InsertOrUpdateRosterItem(&ri) + s := NewRoster() + _, _ = s.UpsertRosterItem(context.Background(), &ri) + + gr, _ := s.FetchRosterGroups(context.Background(), "user") + require.Len(t, gr, 2) + + require.Contains(t, gr, "general") + require.Contains(t, gr, "friends") - s.EnableMockedError() - _, err := s.DeleteRosterItem("user", "contact") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() + EnableMockedError() + _, err := s.DeleteRosterItem(context.Background(), "user", "contact") + require.Equal(t, ErrMocked, err) + DisableMockedError() - _, err = s.DeleteRosterItem("user", "contact") + _, err = s.DeleteRosterItem(context.Background(), "user", "contact") require.Nil(t, err) - _, err = s.DeleteRosterItem("user2", "contact") + _, err = s.DeleteRosterItem(context.Background(), "user2", "contact") require.Nil(t, err) // delete not existing roster item... - ri2, _ := s.FetchRosterItem("user", "contact") + ri2, _ := s.FetchRosterItem(context.Background(), "user", "contact") require.Nil(t, ri2) + + gr, _ = s.FetchRosterGroups(context.Background(), "user") + require.Len(t, gr, 0) } func TestMemoryStorage_InsertRosterNotification(t *testing.T) { @@ -147,11 +165,11 @@ func TestMemoryStorage_InsertRosterNotification(t *testing.T) { JID: "romeo@jackal.im", Presence: &xmpp.Presence{}, } - s := New() - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.InsertOrUpdateRosterNotification(&rn)) - s.DisableMockedError() - require.Nil(t, s.InsertOrUpdateRosterNotification(&rn)) + s := NewRoster() + EnableMockedError() + require.Equal(t, ErrMocked, s.UpsertRosterNotification(context.Background(), &rn)) + DisableMockedError() + require.Nil(t, s.UpsertRosterNotification(context.Background(), &rn)) } func TestMemoryStorage_FetchRosterNotifications(t *testing.T) { @@ -165,20 +183,21 @@ func TestMemoryStorage_FetchRosterNotifications(t *testing.T) { JID: "ortuman2@jackal.im", Presence: &xmpp.Presence{}, } - s := New() - _ = s.InsertOrUpdateRosterNotification(&rn1) - _ = s.InsertOrUpdateRosterNotification(&rn2) + s := NewRoster() + _ = s.UpsertRosterNotification(context.Background(), &rn1) + _ = s.UpsertRosterNotification(context.Background(), &rn2) from, _ := jid.NewWithString("ortuman2@jackal.im", true) to, _ := jid.NewWithString("romeo@jackal.im", true) rn2.Presence = xmpp.NewPresence(from, to, xmpp.SubscribeType) - _ = s.InsertOrUpdateRosterNotification(&rn2) + _ = s.UpsertRosterNotification(context.Background(), &rn2) + + EnableMockedError() + _, err := s.FetchRosterNotifications(context.Background(), "romeo") + require.Equal(t, ErrMocked, err) + DisableMockedError() - s.EnableMockedError() - _, err := s.FetchRosterNotifications("romeo") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - rns, err := s.FetchRosterNotifications("romeo") + rns, err := s.FetchRosterNotifications(context.Background(), "romeo") require.Nil(t, err) require.Equal(t, 2, len(rns)) require.Equal(t, "ortuman@jackal.im", rns[0].JID) @@ -191,17 +210,19 @@ func TestMemoryStorage_DeleteRosterNotification(t *testing.T) { JID: "romeo@jackal.im", Presence: &xmpp.Presence{}, } - s := New() - _ = s.InsertOrUpdateRosterNotification(&rn1) + s := NewRoster() + _ = s.UpsertRosterNotification(context.Background(), &rn1) - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.DeleteRosterNotification("ortuman", "romeo@jackal.im")) - s.DisableMockedError() - require.Nil(t, s.DeleteRosterNotification("ortuman", "romeo@jackal.im")) + EnableMockedError() + require.Equal(t, ErrMocked, s.DeleteRosterNotification(context.Background(), "ortuman", "romeo@jackal.im")) + DisableMockedError() - rns, err := s.FetchRosterNotifications("romeo") + require.Nil(t, s.DeleteRosterNotification(context.Background(), "ortuman", "romeo@jackal.im")) + + rns, err := s.FetchRosterNotifications(context.Background(), "romeo") require.Nil(t, err) require.Equal(t, 0, len(rns)) + // delete not existing roster notification... - require.Nil(t, s.DeleteRosterNotification("ortuman2", "romeo@jackal.im")) + require.Nil(t, s.DeleteRosterNotification(context.Background(), "ortuman2", "romeo@jackal.im")) } diff --git a/storage/memory/storage.go b/storage/memory/storage.go new file mode 100644 index 000000000..945e84e84 --- /dev/null +++ b/storage/memory/storage.go @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "errors" + "sync" + + "github.com/ortuman/jackal/model/serializer" +) + +var ( + mockErrMu sync.RWMutex + mockErr bool + invokeLimit int32 + invokeCount int32 +) + +// ErrMocked represents in memory mocked error value. +var ErrMocked = errors.New("memstorage: mocked error") + +type memoryStorage struct { + mu sync.RWMutex + b map[string][]byte +} + +func newStorage() *memoryStorage { + return &memoryStorage{b: make(map[string][]byte)} +} + +// EnableMockedError enables in memory mocked error. +func EnableMockedError() { + EnableMockedErrorWithInvokeLimit(1) +} + +// EnableMockedErrorWithInvokeLimit enables in memory mocked error after a given invocation limit is reached. +func EnableMockedErrorWithInvokeLimit(limit int32) { + mockErrMu.Lock() + defer mockErrMu.Unlock() + mockErr = true + invokeLimit = limit + invokeCount = 0 +} + +// DisableMockedError disables in memory mocked error. +func DisableMockedError() { + mockErrMu.Lock() + defer mockErrMu.Unlock() + mockErr = false +} + +func (m *memoryStorage) inWriteLock(f func() error) error { + if err := checkMockedError(); err != nil { + return err + } + m.mu.Lock() + err := f() + m.mu.Unlock() + return err +} + +func (m *memoryStorage) inReadLock(f func() error) error { + if err := checkMockedError(); err != nil { + return err + } + m.mu.RLock() + err := f() + m.mu.RUnlock() + return err +} + +func (m *memoryStorage) saveEntity(k string, entity serializer.Serializer) error { + b, err := serializer.Serialize(entity) + if err != nil { + return err + } + return m.inWriteLock(func() error { + m.b[k] = b + return nil + }) +} + +func (m *memoryStorage) saveEntities(k string, entities interface{}) error { + b, err := serializer.SerializeSlice(entities) + if err != nil { + return err + } + return m.inWriteLock(func() error { + m.b[k] = b + return nil + }) +} + +func (m *memoryStorage) getEntity(k string, entity serializer.Deserializer) (bool, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[k] + return nil + }); err != nil { + return false, err + } + if b == nil { + return false, nil + } + if err := serializer.Deserialize(b, entity); err != nil { + return false, err + } + return true, nil +} + +func (m *memoryStorage) updateInWriteLock(k string, f func(b []byte) ([]byte, error)) error { + return m.inWriteLock(func() error { + b, err := f(m.b[k]) + if err != nil { + return err + } + m.b[k] = b + return nil + }) +} + +func (m *memoryStorage) getEntities(k string, entities interface{}) (bool, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[k] + return nil + }); err != nil { + return false, err + } + if b == nil { + return false, nil + } + if err := serializer.DeserializeSlice(b, entities); err != nil { + return false, err + } + return true, nil +} + +func (m *memoryStorage) deleteKey(k string) error { + return m.inWriteLock(func() error { + delete(m.b, k) + return nil + }) +} + +func (m *memoryStorage) keyExists(k string) (bool, error) { + var b []byte + if err := m.inReadLock(func() error { + b = m.b[k] + return nil + }); err != nil { + return false, err + } + return b != nil, nil +} + +func checkMockedError() error { + mockErrMu.Lock() + defer mockErrMu.Unlock() + + if mockErr { + invokeCount++ + if invokeCount >= invokeLimit { + return ErrMocked + } + } + return nil +} diff --git a/storage/memory/user.go b/storage/memory/user.go new file mode 100644 index 000000000..3e6199a60 --- /dev/null +++ b/storage/memory/user.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + "github.com/ortuman/jackal/model" +) + +// User represents an in-memory user storage. +type User struct { + *memoryStorage +} + +// NewUser returns an instance of User in-memory storage. +func NewUser() *User { + return &User{memoryStorage: newStorage()} +} + +// UpsertUser inserts a new user entity into storage, or updates it in case it's been previously inserted. +func (m *User) UpsertUser(_ context.Context, user *model.User) error { + return m.saveEntity(userKey(user.Username), user) +} + +// DeleteUser deletes a user entity from storage. +func (m *User) DeleteUser(_ context.Context, username string) error { + return m.deleteKey(userKey(username)) +} + +// FetchUser retrieves from storage a user entity. +func (m *User) FetchUser(_ context.Context, username string) (*model.User, error) { + var user model.User + ok, err := m.getEntity(userKey(username), &user) + switch err { + case nil: + if ok { + return &user, nil + } + return nil, nil + default: + return nil, err + } +} + +// UserExists returns whether or not a user exists within storage. +func (m *User) UserExists(_ context.Context, username string) (bool, error) { + return m.keyExists(userKey(username)) +} + +func userKey(username string) string { + return "users:" + username +} diff --git a/storage/memory/user_test.go b/storage/memory/user_test.go new file mode 100644 index 000000000..837459c0c --- /dev/null +++ b/storage/memory/user_test.go @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + "testing" + + "github.com/ortuman/jackal/model" + "github.com/stretchr/testify/require" +) + +func TestMemoryStorage_InsertUser(t *testing.T) { + u := model.User{Username: "ortuman", Password: "1234"} + s := NewUser() + EnableMockedError() + err := s.UpsertUser(context.Background(), &u) + require.Equal(t, ErrMocked, err) + DisableMockedError() + err = s.UpsertUser(context.Background(), &u) + require.Nil(t, err) +} + +func TestMemoryStorage_UserExists(t *testing.T) { + s := NewUser() + EnableMockedError() + _, err := s.UserExists(context.Background(), "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() + ok, err := s.UserExists(context.Background(), "ortuman") + require.Nil(t, err) + require.False(t, ok) +} + +func TestMemoryStorage_FetchUser(t *testing.T) { + u := model.User{Username: "ortuman", Password: "1234"} + s := NewUser() + _ = s.UpsertUser(context.Background(), &u) + + EnableMockedError() + _, err := s.FetchUser(context.Background(), "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() + + usr, _ := s.FetchUser(context.Background(), "romeo") + require.Nil(t, usr) + + usr, _ = s.FetchUser(context.Background(), "ortuman") + require.NotNil(t, usr) +} + +func TestMemoryStorage_DeleteUser(t *testing.T) { + u := model.User{Username: "ortuman", Password: "1234"} + s := NewUser() + _ = s.UpsertUser(context.Background(), &u) + + EnableMockedError() + require.Equal(t, ErrMocked, s.DeleteUser(context.Background(), "ortuman")) + DisableMockedError() + require.Nil(t, s.DeleteUser(context.Background(), "ortuman")) + + usr, _ := s.FetchUser(context.Background(), "ortuman") + require.Nil(t, usr) +} diff --git a/storage/memory/vcard.go b/storage/memory/vcard.go new file mode 100644 index 000000000..8d9397c4f --- /dev/null +++ b/storage/memory/vcard.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package memorystorage + +import ( + "context" + + "github.com/ortuman/jackal/xmpp" +) + +// VCard represents an in-memory vCard storage. +type VCard struct { + *memoryStorage +} + +// NewVCard returns an instance of VCard in-memory storage. +func NewVCard() *VCard { + return &VCard{memoryStorage: newStorage()} +} + +// UpsertVCard inserts a new vCard element into storage, or updates it in case it's been previously inserted. +func (m *VCard) UpsertVCard(_ context.Context, vCard xmpp.XElement, username string) error { + return m.saveEntity(vCardKey(username), vCard) +} + +// FetchVCard retrieves from storage a vCard element associated to a given user. +func (m *VCard) FetchVCard(_ context.Context, username string) (xmpp.XElement, error) { + var vCard xmpp.Element + ok, err := m.getEntity(vCardKey(username), &vCard) + switch err { + case nil: + if ok { + return &vCard, nil + } + return nil, nil + default: + return nil, err + } +} + +func vCardKey(username string) string { + return "vCards:" + username +} diff --git a/storage/memstorage/vcard_test.go b/storage/memory/vcard_test.go similarity index 53% rename from storage/memstorage/vcard_test.go rename to storage/memory/vcard_test.go index a5614eb36..815e81e05 100644 --- a/storage/memstorage/vcard_test.go +++ b/storage/memory/vcard_test.go @@ -3,9 +3,10 @@ * See the LICENSE file for more information. */ -package memstorage +package memorystorage import ( + "context" "testing" "github.com/ortuman/jackal/xmpp" @@ -18,11 +19,11 @@ func TestMemoryStorage_InsertVCard(t *testing.T) { fn.SetText("Miguel Ɓngel") vCard.AppendElement(fn) - s := New() - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.InsertOrUpdateVCard(vCard, "ortuman")) - s.DisableMockedError() - require.Nil(t, s.InsertOrUpdateVCard(vCard, "ortuman")) + s := NewVCard() + EnableMockedError() + require.Equal(t, ErrMocked, s.UpsertVCard(context.Background(), vCard, "ortuman")) + DisableMockedError() + require.Nil(t, s.UpsertVCard(context.Background(), vCard, "ortuman")) } func TestMemoryStorage_FetchVCard(t *testing.T) { @@ -31,14 +32,14 @@ func TestMemoryStorage_FetchVCard(t *testing.T) { fn.SetText("Miguel Ɓngel") vCard.AppendElement(fn) - s := New() - _ = s.InsertOrUpdateVCard(vCard, "ortuman") + s := NewVCard() + _ = s.UpsertVCard(context.Background(), vCard, "ortuman") - s.EnableMockedError() - _, err := s.FetchVCard("ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() + EnableMockedError() + _, err := s.FetchVCard(context.Background(), "ortuman") + require.Equal(t, ErrMocked, err) + DisableMockedError() - elem, _ := s.FetchVCard("ortuman") + elem, _ := s.FetchVCard(context.Background(), "ortuman") require.NotNil(t, elem) } diff --git a/storage/memstorage/block_list.go b/storage/memstorage/block_list.go deleted file mode 100644 index 9b3269b65..000000000 --- a/storage/memstorage/block_list.go +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/serializer" -) - -// InsertBlockListItems inserts a set of block list item entities -// into storage, only in case they haven't been previously inserted. -func (m *Storage) InsertBlockListItems(items []model.BlockListItem) error { - return m.inWriteLock(func() error { - for _, item := range items { - blItems, err := m.fetchUserBlockListItems(item.Username) - if err != nil { - return err - } - if blItems != nil { - for _, blItem := range blItems { - if blItem.JID == item.JID { - goto done - } - } - blItems = append(blItems, item) - } else { - blItems = []model.BlockListItem{item} - } - if err := m.upsertBlockListItems(blItems, item.Username); err != nil { - return err - } - done: - } - return nil - }) -} - -// DeleteBlockListItems deletes a set of block list item entities from storage. -func (m *Storage) DeleteBlockListItems(items []model.BlockListItem) error { - return m.inWriteLock(func() error { - for _, itm := range items { - blItems, err := m.fetchUserBlockListItems(itm.Username) - if err != nil { - return err - } - for i, blItem := range blItems { - if blItem.JID == itm.JID { - // delete item - blItems = append(blItems[:i], blItems[i+1:]...) - if err := m.upsertBlockListItems(blItems, itm.Username); err != nil { - return err - } - break - } - } - } - return nil - }) -} - -// FetchBlockListItems retrieves from storage all block list item entities -// associated to a given user. -func (m *Storage) FetchBlockListItems(username string) ([]model.BlockListItem, error) { - var blItems []model.BlockListItem - if err := m.inReadLock(func() error { - var fnErr error - blItems, fnErr = m.fetchUserBlockListItems(username) - return fnErr - }); err != nil { - return nil, err - } - return blItems, nil -} - -func (m *Storage) upsertBlockListItems(blItems []model.BlockListItem, username string) error { - b, err := serializer.SerializeSlice(&blItems) - if err != nil { - return err - } - m.bytes[blockListItemKey(username)] = b - return nil -} - -func (m *Storage) fetchUserBlockListItems(username string) ([]model.BlockListItem, error) { - b := m.bytes[blockListItemKey(username)] - if b == nil { - return nil, nil - } - var blItems []model.BlockListItem - if err := serializer.DeserializeSlice(b, &blItems); err != nil { - return nil, err - } - return blItems, nil -} - -func blockListItemKey(username string) string { - return "blockListItems:" + username -} diff --git a/storage/memstorage/block_list_test.go b/storage/memstorage/block_list_test.go deleted file mode 100644 index d57210040..000000000 --- a/storage/memstorage/block_list_test.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "testing" - - "github.com/ortuman/jackal/model" - "github.com/stretchr/testify/require" -) - -func TestMemoryStorage_InsertOrUpdateBlockListItems(t *testing.T) { - items := []model.BlockListItem{ - {Username: "ortuman", JID: "user@jackal.im"}, - {Username: "ortuman", JID: "romeo@jackal.im"}, - {Username: "ortuman", JID: "juliet@jackal.im"}, - } - s := New() - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.InsertBlockListItems(items)) - s.DisableMockedError() - - _ = s.InsertBlockListItems(items) - - s.EnableMockedError() - _, err := s.FetchBlockListItems("ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - - sItems, _ := s.FetchBlockListItems("ortuman") - require.Equal(t, items, sItems) -} - -func TestMemoryStorage_DeleteBlockListItems(t *testing.T) { - items := []model.BlockListItem{ - {Username: "ortuman", JID: "user@jackal.im"}, - {Username: "ortuman", JID: "romeo@jackal.im"}, - {Username: "ortuman", JID: "juliet@jackal.im"}, - } - s := New() - _ = s.InsertBlockListItems(items) - - delItems := []model.BlockListItem{{Username: "ortuman", JID: "romeo@jackal.im"}} - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.DeleteBlockListItems(delItems)) - s.DisableMockedError() - - _ = s.DeleteBlockListItems(delItems) - sItems, _ := s.FetchBlockListItems("ortuman") - require.Equal(t, []model.BlockListItem{ - {Username: "ortuman", JID: "user@jackal.im"}, - {Username: "ortuman", JID: "juliet@jackal.im"}, - }, sItems) -} diff --git a/storage/memstorage/memstorage.go b/storage/memstorage/memstorage.go deleted file mode 100644 index 6c237b5ca..000000000 --- a/storage/memstorage/memstorage.go +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "errors" - "sync" -) - -// ErrMockedError will be returned by any Storage method -// when mocked error is activated. -var ErrMockedError = errors.New("memstorage: mocked error") - -// Storage represents an in memory storage sub system. -type Storage struct { - mockErrMu sync.Mutex - mockingErr bool - invokeLimit int32 - invokeCount int32 - mu sync.RWMutex - bytes map[string][]byte -} - -// New returns a new in memory storage instance. -func New() *Storage { - return &Storage{ - bytes: make(map[string][]byte), - } -} - -// IsClusterCompatible returns whether or not the underlying storage subsystem can be used in cluster mode. -func (m *Storage) IsClusterCompatible() bool { return false } - -// Close shuts down in memory storage sub system. -func (m *Storage) Close() error { - return nil -} - -// EnableMockedError enables in memory mocked error. -func (m *Storage) EnableMockedError() { - m.EnableMockedErrorWithInvokeLimit(1) -} - -// EnableMockedErrorWithInvokeLimit enables in memory mocked error after a given invocation limit is reached. -func (m *Storage) EnableMockedErrorWithInvokeLimit(invokeLimit int32) { - m.mockErrMu.Lock() - defer m.mockErrMu.Unlock() - m.mockingErr = true - m.invokeLimit = invokeLimit - m.invokeCount = 0 -} - -// DisableMockedError disables in memory mocked error. -func (m *Storage) DisableMockedError() { - m.mockErrMu.Lock() - defer m.mockErrMu.Unlock() - m.mockingErr = false -} - -func (m *Storage) inWriteLock(f func() error) error { - m.mockErrMu.Lock() - defer m.mockErrMu.Unlock() - m.invokeCount++ - if m.invokeCount == m.invokeLimit { - return ErrMockedError - } - m.mu.Lock() - err := f() - m.mu.Unlock() - return err -} - -func (m *Storage) inReadLock(f func() error) error { - m.mockErrMu.Lock() - defer m.mockErrMu.Unlock() - m.invokeCount++ - if m.invokeCount == m.invokeLimit { - return ErrMockedError - } - m.mu.RLock() - err := f() - m.mu.RUnlock() - return err -} diff --git a/storage/memstorage/offline.go b/storage/memstorage/offline.go deleted file mode 100644 index a8076e57a..000000000 --- a/storage/memstorage/offline.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "github.com/ortuman/jackal/model/serializer" - "github.com/ortuman/jackal/xmpp" -) - -// InsertOfflineMessage inserts a new message element into user's offline queue. -func (m *Storage) InsertOfflineMessage(message *xmpp.Message, username string) error { - return m.inWriteLock(func() error { - messages, err := m.fetchUserOfflineMessages(username) - if err != nil { - return err - } - messages = append(messages, *message) - - b, err := serializer.SerializeSlice(&messages) - if err != nil { - return err - } - m.bytes[offlineMessageKey(username)] = b - return nil - }) -} - -// CountOfflineMessages returns current length of user's offline queue. -func (m *Storage) CountOfflineMessages(username string) (int, error) { - var messages []xmpp.Message - if err := m.inReadLock(func() error { - var fnErr error - messages, fnErr = m.fetchUserOfflineMessages(username) - return fnErr - }); err != nil { - return 0, err - } - return len(messages), nil -} - -// FetchOfflineMessages retrieves from storage current user offline queue. -func (m *Storage) FetchOfflineMessages(username string) ([]xmpp.Message, error) { - var messages []xmpp.Message - if err := m.inReadLock(func() error { - var fnErr error - messages, fnErr = m.fetchUserOfflineMessages(username) - return fnErr - }); err != nil { - return nil, err - } - return messages, nil -} - -// DeleteOfflineMessages clears a user offline queue. -func (m *Storage) DeleteOfflineMessages(username string) error { - return m.inWriteLock(func() error { - delete(m.bytes, offlineMessageKey(username)) - return nil - }) -} - -func (m *Storage) fetchUserOfflineMessages(username string) ([]xmpp.Message, error) { - b := m.bytes[offlineMessageKey(username)] - if b == nil { - return nil, nil - } - var messages []xmpp.Message - if err := serializer.DeserializeSlice(b, &messages); err != nil { - return nil, err - } - return messages, nil -} - -func offlineMessageKey(username string) string { - return "offlineMessages:" + username -} diff --git a/storage/memstorage/private.go b/storage/memstorage/private.go deleted file mode 100644 index a0caf62fe..000000000 --- a/storage/memstorage/private.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "github.com/ortuman/jackal/model/serializer" - "github.com/ortuman/jackal/xmpp" -) - -// InsertOrUpdatePrivateXML inserts a new private element into storage, -// or updates it in case it's been previously inserted. -func (m *Storage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error { - var priv []xmpp.Element - - // convert to concrete type - for _, el := range privateXML { - priv = append(priv, *xmpp.NewElementFromElement(el)) - } - b, err := serializer.SerializeSlice(&priv) - if err != nil { - return err - } - return m.inWriteLock(func() error { - m.bytes[privateStorageKey(username, namespace)] = b - return nil - }) -} - -// FetchPrivateXML retrieves from storage a private element. -func (m *Storage) FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) { - var b []byte - if err := m.inReadLock(func() error { - b = m.bytes[privateStorageKey(username, namespace)] - return nil - }); err != nil { - return nil, err - } - if b == nil { - return nil, nil - } - var priv []xmpp.Element - if err := serializer.DeserializeSlice(b, &priv); err != nil { - return nil, err - } - var ret []xmpp.XElement - for _, e := range priv { - ret = append(ret, &e) - } - return ret, nil -} - -func privateStorageKey(username, namespace string) string { - return "privateElements:" + username + ":" + namespace -} diff --git a/storage/memstorage/private_test.go b/storage/memstorage/private_test.go deleted file mode 100644 index 8dbcee413..000000000 --- a/storage/memstorage/private_test.go +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "testing" - - "github.com/ortuman/jackal/xmpp" - "github.com/stretchr/testify/require" -) - -func TestMemoryStorage_InsertPrivateXML(t *testing.T) { - private := xmpp.NewElementNamespace("exodus", "exodus:ns") - - s := New() - s.EnableMockedError() - err := s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - err = s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") - require.Nil(t, err) -} - -func TestMemoryStorage_FetchPrivateXML(t *testing.T) { - private := xmpp.NewElementNamespace("exodus", "exodus:ns") - - s := New() - _ = s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") - - s.EnableMockedError() - _, err := s.FetchPrivateXML("exodus:ns", "ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - elems, _ := s.FetchPrivateXML("exodus:ns", "ortuman") - require.Equal(t, 1, len(elems)) -} diff --git a/storage/memstorage/user.go b/storage/memstorage/user.go deleted file mode 100644 index f130859f9..000000000 --- a/storage/memstorage/user.go +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "github.com/ortuman/jackal/model" - "github.com/ortuman/jackal/model/serializer" -) - -// InsertOrUpdateUser inserts a new user entity into storage, -// or updates it in case it's been previously inserted. -func (m *Storage) InsertOrUpdateUser(user *model.User) error { - b, err := serializer.Serialize(user) - if err != nil { - return err - } - return m.inWriteLock(func() error { - m.bytes[userKey(user.Username)] = b - return nil - }) -} - -// DeleteUser deletes a user entity from storage. -func (m *Storage) DeleteUser(username string) error { - return m.inWriteLock(func() error { - delete(m.bytes, userKey(username)) - return nil - }) -} - -// FetchUser retrieves from storage a user entity. -func (m *Storage) FetchUser(username string) (*model.User, error) { - var b []byte - if err := m.inReadLock(func() error { - b = m.bytes[userKey(username)] - return nil - }); err != nil { - return nil, err - } - if b == nil { - return nil, nil - } - var usr model.User - if err := serializer.Deserialize(b, &usr); err != nil { - return nil, err - } - return &usr, nil -} - -// UserExists returns whether or not a user exists within storage. -func (m *Storage) UserExists(username string) (bool, error) { - var b []byte - if err := m.inReadLock(func() error { - b = m.bytes[userKey(username)] - return nil - }); err != nil { - return false, err - } - return b != nil, nil -} - -func userKey(username string) string { - return "users:" + username -} diff --git a/storage/memstorage/user_test.go b/storage/memstorage/user_test.go deleted file mode 100644 index 4eafad6ca..000000000 --- a/storage/memstorage/user_test.go +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "testing" - - "github.com/ortuman/jackal/model" - "github.com/stretchr/testify/require" -) - -func TestMemoryStorage_InsertUser(t *testing.T) { - u := model.User{Username: "ortuman", Password: "1234"} - s := New() - s.EnableMockedError() - err := s.InsertOrUpdateUser(&u) - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - err = s.InsertOrUpdateUser(&u) - require.Nil(t, err) -} - -func TestMemoryStorage_UserExists(t *testing.T) { - s := New() - s.EnableMockedError() - _, err := s.UserExists("ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - ok, err := s.UserExists("ortuman") - require.Nil(t, err) - require.False(t, ok) -} - -func TestMemoryStorage_FetchUser(t *testing.T) { - u := model.User{Username: "ortuman", Password: "1234"} - s := New() - _ = s.InsertOrUpdateUser(&u) - - s.EnableMockedError() - _, err := s.FetchUser("ortuman") - require.Equal(t, ErrMockedError, err) - s.DisableMockedError() - - usr, _ := s.FetchUser("romeo") - require.Nil(t, usr) - - usr, _ = s.FetchUser("ortuman") - require.NotNil(t, usr) -} - -func TestMemoryStorage_DeleteUser(t *testing.T) { - u := model.User{Username: "ortuman", Password: "1234"} - s := New() - _ = s.InsertOrUpdateUser(&u) - - s.EnableMockedError() - require.Equal(t, ErrMockedError, s.DeleteUser("ortuman")) - s.DisableMockedError() - require.Nil(t, s.DeleteUser("ortuman")) - - usr, _ := s.FetchUser("ortuman") - require.Nil(t, usr) -} diff --git a/storage/memstorage/vcard.go b/storage/memstorage/vcard.go deleted file mode 100644 index 6979f8571..000000000 --- a/storage/memstorage/vcard.go +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package memstorage - -import ( - "github.com/ortuman/jackal/model/serializer" - "github.com/ortuman/jackal/xmpp" -) - -// InsertOrUpdateVCard inserts a new vCard element into storage, -// or updates it in case it's been previously inserted. -func (m *Storage) InsertOrUpdateVCard(vCard xmpp.XElement, username string) error { - b, err := serializer.Serialize(vCard) - if err != nil { - return err - } - return m.inWriteLock(func() error { - m.bytes[vCardKey(username)] = b - return nil - }) -} - -// FetchVCard retrieves from storage a vCard element associated -// to a given user. -func (m *Storage) FetchVCard(username string) (xmpp.XElement, error) { - var b []byte - if err := m.inReadLock(func() error { - b = m.bytes[vCardKey(username)] - return nil - }); err != nil { - return nil, err - } - if b == nil { - return nil, nil - } - var vCard xmpp.Element - if err := serializer.Deserialize(b, &vCard); err != nil { - return nil, err - } - return &vCard, nil -} - -func vCardKey(username string) string { - return "vCards:" + username -} diff --git a/storage/mysql/block_list.go b/storage/mysql/block_list.go index 3c94af51b..c7873cdff 100644 --- a/storage/mysql/block_list.go +++ b/storage/mysql/block_list.go @@ -6,66 +6,60 @@ package mysql import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" "github.com/ortuman/jackal/model" ) -// InsertBlockListItems inserts a set of block list item entities -// into storage, only in case they haven't been previously inserted. -func (s *Storage) InsertBlockListItems(items []model.BlockListItem) error { - return s.inTransaction(func(tx *sql.Tx) error { - for _, item := range items { - _, err := sq.Insert("blocklist_items"). - Options("IGNORE"). - Columns("username", "jid", "created_at"). - Values(item.Username, item.JID, nowExpr). - RunWith(tx).Exec() - if err != nil { - return err - } - } - return nil - }) +type mySQLBlockList struct { + *mySQLStorage } -// DeleteBlockListItems deletes a set of block list item entities from storage. -func (s *Storage) DeleteBlockListItems(items []model.BlockListItem) error { - return s.inTransaction(func(tx *sql.Tx) error { - for _, item := range items { - _, err := sq.Delete("blocklist_items"). - Where(sq.And{sq.Eq{"username": item.Username}, sq.Eq{"jid": item.JID}}). - RunWith(tx).Exec() - if err != nil { - return err - } - } - return nil - }) +func newBlockList(db *sql.DB) *mySQLBlockList { + return &mySQLBlockList{ + mySQLStorage: newStorage(db), + } +} + +func (s *mySQLBlockList) InsertBlockListItem(ctx context.Context, item *model.BlockListItem) error { + _, err := sq.Insert("blocklist_items"). + Options("IGNORE"). + Columns("username", "jid", "created_at"). + Values(item.Username, item.JID, nowExpr). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLBlockList) DeleteBlockListItem(ctx context.Context, item *model.BlockListItem) error { + _, err := sq.Delete("blocklist_items"). + Where(sq.And{sq.Eq{"username": item.Username}, sq.Eq{"jid": item.JID}}). + RunWith(s.db).ExecContext(ctx) + return err } -// FetchBlockListItems retrieves from storage all block list item entities -// associated to a given user. -func (s *Storage) FetchBlockListItems(username string) ([]model.BlockListItem, error) { +func (s *mySQLBlockList) FetchBlockListItems(ctx context.Context, username string) ([]model.BlockListItem, error) { q := sq.Select("username", "jid"). From("blocklist_items"). Where(sq.Eq{"username": username}). OrderBy("created_at") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, err } - defer rows.Close() - return s.scanBlockListItemEntities(rows) + defer func() { _ = rows.Close() }() + return scanBlockListItemEntities(rows) } -func (s *Storage) scanBlockListItemEntities(scanner rowsScanner) ([]model.BlockListItem, error) { +func scanBlockListItemEntities(scanner rowsScanner) ([]model.BlockListItem, error) { var ret []model.BlockListItem for scanner.Next() { var it model.BlockListItem - scanner.Scan(&it.Username, &it.JID) + if err := scanner.Scan(&it.Username, &it.JID); err != nil { + return nil, err + } ret = append(ret, it) } return ret, nil diff --git a/storage/mysql/block_list_test.go b/storage/mysql/block_list_test.go index 1c5292764..0f1946d65 100644 --- a/storage/mysql/block_list_test.go +++ b/storage/mysql/block_list_test.go @@ -6,6 +6,7 @@ package mysql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" @@ -14,74 +15,70 @@ import ( ) func TestMySQLStorageInsertBlockListItems(t *testing.T) { - s, mock := NewMock() - mock.ExpectBegin() + s, mock := newBlockListMock() mock.ExpectExec("INSERT IGNORE INTO blocklist_items (.+)"). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - err := s.InsertBlockListItems([]model.BlockListItem{{Username: "ortuman", JID: "noelia@jackal.im"}}) + err := s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() - mock.ExpectBegin() + s, mock = newBlockListMock() mock.ExpectExec("INSERT IGNORE INTO blocklist_items (.+)").WillReturnError(errMySQLStorage) - mock.ExpectRollback() - err = s.InsertBlockListItems([]model.BlockListItem{{Username: "ortuman", JID: "noelia@jackal.im"}}) + err = s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } func TestMySQLFetchBlockListItems(t *testing.T) { var blockListColumns = []string{"username", "jid"} - s, mock := NewMock() + s, mock := newBlockListMock() mock.ExpectQuery("SELECT (.+) FROM blocklist_items (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(blockListColumns).AddRow("ortuman", "noelia@jackal.im")) - _, err := s.FetchBlockListItems("ortuman") + _, err := s.FetchBlockListItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newBlockListMock() mock.ExpectQuery("SELECT (.+) FROM blocklist_items (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - _, err = s.FetchBlockListItems("ortuman") + _, err = s.FetchBlockListItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } func TestMySQLStorageDeleteBlockListItems(t *testing.T) { - s, mock := NewMock() - mock.ExpectBegin() + s, mock := newBlockListMock() mock.ExpectExec("DELETE FROM blocklist_items (.+)"). WithArgs("ortuman"). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - s, mock = NewMock() - mock.ExpectBegin() + s, mock = newBlockListMock() mock.ExpectExec("DELETE FROM blocklist_items (.+)"). WithArgs("ortuman", "noelia@jackal.im"). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - delItems := []model.BlockListItem{{Username: "ortuman", JID: "noelia@jackal.im"}} - err := s.DeleteBlockListItems(delItems) + err := s.DeleteBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() - mock.ExpectBegin() + s, mock = newBlockListMock() mock.ExpectExec("DELETE FROM blocklist_items (.+)"). WillReturnError(errMySQLStorage) - mock.ExpectRollback() - err = s.DeleteBlockListItems(delItems) + err = s.DeleteBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } + +func newBlockListMock() (*mySQLBlockList, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLBlockList{ + mySQLStorage: s, + }, sqlMock +} diff --git a/storage/mysql/config.go b/storage/mysql/config.go index cf558528d..8ebaf27fc 100644 --- a/storage/mysql/config.go +++ b/storage/mysql/config.go @@ -1,3 +1,8 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + package mysql // DefaultPoolSize defines the default size of MySQL connection pool @@ -21,7 +26,6 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { if err := unmarshal(&parsed); err != nil { return err } - *c = Config(parsed) return nil diff --git a/storage/mysql/mysql.go b/storage/mysql/mysql.go new file mode 100644 index 000000000..cf2cbf01e --- /dev/null +++ b/storage/mysql/mysql.go @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "database/sql" + "fmt" + "time" + + _ "github.com/go-sql-driver/mysql" // SQL driver + "github.com/ortuman/jackal/log" + "github.com/ortuman/jackal/storage/repository" +) + +type mySQLContainer struct { + user *mySQLUser + roster *mySQLRoster + presences *mySQLPresences + vCard *mySQLVCard + priv *mySQLPrivate + blockList *mySQLBlockList + pubSub *mySQLPubSub + offline *mySQLOffline + room *mySQLRoom + occupant *mySQLOccupant + + h *sql.DB + doneCh chan chan bool +} + +// New initializes MySQL storage and returns associated container. +func New(cfg *Config) (repository.Container, error) { + var err error + c := &mySQLContainer{doneCh: make(chan chan bool, 1)} + host := cfg.Host + usr := cfg.User + pass := cfg.Password + db := cfg.Database + poolSize := cfg.PoolSize + + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", usr, pass, host, db) + c.h, err = sql.Open("mysql", dsn) + if err != nil { + return nil, err + } + c.h.SetMaxOpenConns(poolSize) // set max opened connection count + + if err := c.h.Ping(); err != nil { + return nil, err + } + go c.loop() + + c.user = newUser(c.h) + c.roster = newRoster(c.h) + c.presences = newPresences(c.h) + c.vCard = newVCard(c.h) + c.priv = newPrivate(c.h) + c.blockList = newBlockList(c.h) + c.pubSub = newPubSub(c.h) + c.offline = newOffline(c.h) + c.occupant = newOccupant(c.h) + c.room = newRoom(c.h) + + return c, nil +} + +func (c *mySQLContainer) User() repository.User { return c.user } +func (c *mySQLContainer) Roster() repository.Roster { return c.roster } +func (c *mySQLContainer) Presences() repository.Presences { return c.presences } +func (c *mySQLContainer) VCard() repository.VCard { return c.vCard } +func (c *mySQLContainer) Private() repository.Private { return c.priv } +func (c *mySQLContainer) BlockList() repository.BlockList { return c.blockList } +func (c *mySQLContainer) PubSub() repository.PubSub { return c.pubSub } +func (c *mySQLContainer) Offline() repository.Offline { return c.offline } +func (c *mySQLContainer) Room() repository.Room { return c.room } +func (c *mySQLContainer) Occupant() repository.Occupant { return c.occupant } + +func (c *mySQLContainer) Close(ctx context.Context) error { + ch := make(chan bool) + c.doneCh <- ch + select { + case <-ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *mySQLContainer) IsClusterCompatible() bool { return true } + +func (c *mySQLContainer) loop() { + tc := time.NewTicker(time.Second * 15) + defer tc.Stop() + + for { + select { + case <-tc.C: + if err := c.h.Ping(); err != nil { + log.Error(err) + } + case ch := <-c.doneCh: + if err := c.h.Close(); err != nil { + log.Error(err) + } + close(ch) + return + } + } +} diff --git a/storage/mysql/mysql_test.go b/storage/mysql/mysql_test.go new file mode 100644 index 000000000..d09e0620f --- /dev/null +++ b/storage/mysql/mysql_test.go @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import "errors" + +var ( + errMySQLStorage = errors.New("mysql: storage error") +) diff --git a/storage/mysql/occupant.go b/storage/mysql/occupant.go new file mode 100644 index 000000000..d924f9cc8 --- /dev/null +++ b/storage/mysql/occupant.go @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "database/sql" + + sq "github.com/Masterminds/squirrel" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +type mySQLOccupant struct { + *mySQLStorage +} + +func newOccupant(db *sql.DB) *mySQLOccupant { + return &mySQLOccupant{ + mySQLStorage: newStorage(db), + } +} + +func (o *mySQLOccupant) UpsertOccupant(ctx context.Context, occ *mucmodel.Occupant) error { + return o.inTransaction(ctx, func(tx *sql.Tx) error { + // store occupants data (except for resources) + columns := []string{"occupant_jid", "bare_jid", "affiliation", "role"} + values := []interface{}{occ.OccupantJID.String(), occ.BareJID.String(), + occ.GetAffiliation(), occ.GetRole()} + q := sq.Insert("occupants"). + Columns(columns...). + Values(values...). + Suffix("ON DUPLICATE KEY UPDATE affiliation = ?, role = ?", occ.GetAffiliation(), + occ.GetRole()) + + _, err := q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + //store occupants resources + columns = []string{"occupant_jid", "resource"} + for _, res := range occ.GetAllResources() { + values = []interface{}{occ.OccupantJID.String(), res} + q = sq.Insert("resources"). + Columns(columns...). + Values(values...). + Suffix("ON DUPLICATE KEY UPDATE resource = ?", res) + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + return nil + }) +} + +func (o *mySQLOccupant) DeleteOccupant(ctx context.Context, occJID *jid.JID) error { + return o.inTransaction(ctx, func(tx *sql.Tx) error { + _, err := sq.Delete("occupants").Where(sq.Eq{"occupant_jid": occJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("resources").Where(sq.Eq{"occupant_jid": occJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + return nil + }) +} + +func (o *mySQLOccupant) FetchOccupant(ctx context.Context, occJID *jid.JID) (*mucmodel.Occupant, + error) { + tx, err := o.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + occ, err := fetchOccupantData(ctx, tx, occJID) + switch err { + case nil: + case sql.ErrNoRows: + _ = tx.Commit() + return nil, nil + default: + _ = tx.Rollback() + return nil, err + + } + + err = fetchOccupantResources(ctx, tx, occ, occJID) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return occ, nil +} + +func fetchOccupantData(ctx context.Context, tx *sql.Tx, occJID *jid.JID) (*mucmodel.Occupant, + error) { + var occ *mucmodel.Occupant + q := sq.Select("occupant_jid", "bare_jid", "affiliation", "role"). + From("occupants"). + Where(sq.Eq{"occupant_jid": occJID.String()}) + + var occJIDStr, bareJIDStr, affiliation, role string + err := q.RunWith(tx). + QueryRowContext(ctx). + Scan(&occJIDStr, &bareJIDStr, &affiliation, &role) + switch err { + case nil: + occJIDdb, err := jid.NewWithString(occJIDStr, false) + if err != nil { + return nil, err + } + bareJID, err := jid.NewWithString(bareJIDStr, false) + if err != nil { + return nil, err + } + occ, err = mucmodel.NewOccupant(occJIDdb, bareJID) + if err != nil { + return nil, err + } + err = occ.SetAffiliation(affiliation) + if err != nil { + return nil, err + } + err = occ.SetRole(role) + if err != nil { + return nil, err + } + default: + return nil, err + } + return occ, nil +} + +func fetchOccupantResources(ctx context.Context, tx *sql.Tx, occ *mucmodel.Occupant, + occJID *jid.JID) error { + resources, err := sq.Select("occupant_jid", "resource"). + From("resources"). + Where(sq.Eq{"occupant_jid": occJID.String()}). + RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + for resources.Next() { + var dummy, res string + if err = resources.Scan(&dummy, &res); err != nil { + return err + } + occ.AddResource(res) + } + return nil +} + +func (o *mySQLOccupant) OccupantExists(ctx context.Context, occJID *jid.JID) (bool, error) { + q := sq.Select("COUNT(*)"). + From("occupants"). + Where(sq.Eq{"occupant_jid": occJID.String()}) + + var count int + err := q.RunWith(o.db).QueryRowContext(ctx).Scan(&count) + switch err { + case nil: + return count > 0, nil + default: + return false, err + } +} diff --git a/storage/mysql/occupant_test.go b/storage/mysql/occupant_test.go new file mode 100644 index 000000000..c807be67a --- /dev/null +++ b/storage/mysql/occupant_test.go @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestMySQLStorageInsertOccupant(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + o, _ := mucmodel.NewOccupant(j, j.ToBareJID()) + o.AddResource("yard") + o.SetAffiliation("owner") + o.SetRole("moderator") + + s, mock := newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO occupants (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(o.OccupantJID.String(), o.BareJID.String(), o.GetAffiliation(), o.GetRole(), + o.GetAffiliation(), o.GetRole()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO resources (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(o.OccupantJID.String(), "yard", "yard"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := s.UpsertOccupant(context.Background(), o) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO occupants (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(o.OccupantJID.String(), o.BareJID.String(), o.GetAffiliation(), o.GetRole(), + o.GetAffiliation(), o.GetRole()). + WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.UpsertOccupant(context.Background(), o) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, err, errMocked) +} + +func TestMySQLStorageDeleteOccupant(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + s, mock := newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM occupants (.+)"). + WithArgs(j.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM resources (.+)"). + WithArgs(j.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.DeleteOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM occupants (.+)"). + WithArgs(j.String()).WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.DeleteOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestMySQLStorageFetchOccupant(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + + occColumns := []string{"occupant_jid", "bare_jid", "affiliation", "role"} + resColumns := []string{"occupant_jid", "resource"} + + s, mock := newOccupantMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(occColumns)) + mock.ExpectCommit() + + occ, _ := s.FetchOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, occ) + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(occColumns). + AddRow(j.String(), j.ToBareJID().String(), "owner", "moderator")) + mock.ExpectQuery("SELECT (.+) FROM resources (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(resColumns). + AddRow(j.String(), "phone")) + mock.ExpectCommit() + occ, err := s.FetchOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.NotNil(t, occ) + require.Equal(t, occ.OccupantJID.String(), j.String()) + require.Equal(t, occ.BareJID.String(), j.ToBareJID().String()) + require.Equal(t, occ.GetAffiliation(), "owner") + require.Equal(t, occ.GetRole(), "moderator") + require.Len(t, occ.GetAllResources(), 1) + require.Equal(t, occ.GetAllResources()[0], "phone") + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM occupants (.+)"). + WithArgs(j.String()).WillReturnError(errMocked) + mock.ExpectRollback() + _, err = s.FetchOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestMySQLStorageOccupantExists(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + countCols := []string{"count"} + + s, mock := newOccupantMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(countCols).AddRow(1)) + + ok, err := s.OccupantExists(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.True(t, ok) + + s, mock = newOccupantMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnError(errMocked) + _, err = s.OccupantExists(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func newOccupantMock() (*mySQLOccupant, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLOccupant{ + mySQLStorage: s, + }, sqlMock +} diff --git a/storage/mysql/offline.go b/storage/mysql/offline.go index a3bf2b36e..901e1b3ff 100644 --- a/storage/mysql/offline.go +++ b/storage/mysql/offline.go @@ -6,30 +6,43 @@ package mysql import ( + "context" + "database/sql" + sq "github.com/Masterminds/squirrel" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -// InsertOfflineMessage inserts a new message element into -// user's offline queue. -func (s *Storage) InsertOfflineMessage(message *xmpp.Message, username string) error { +type mySQLOffline struct { + *mySQLStorage + pool *pool.BufferPool +} + +func newOffline(db *sql.DB) *mySQLOffline { + return &mySQLOffline{ + mySQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +func (s *mySQLOffline) InsertOfflineMessage(ctx context.Context, message *xmpp.Message, username string) error { q := sq.Insert("offline_messages"). Columns("username", "data", "created_at"). Values(username, message.String(), nowExpr) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -// CountOfflineMessages returns current length of user's offline queue. -func (s *Storage) CountOfflineMessages(username string) (int, error) { +func (s *mySQLOffline) CountOfflineMessages(ctx context.Context, username string) (int, error) { q := sq.Select("COUNT(*)"). From("offline_messages"). Where(sq.Eq{"username": username}). OrderBy("created_at") var count int - err := q.RunWith(s.db).Scan(&count) + err := q.RunWith(s.db).QueryRowContext(ctx).Scan(&count) switch err { case nil: return count, nil @@ -38,18 +51,17 @@ func (s *Storage) CountOfflineMessages(username string) (int, error) { } } -// FetchOfflineMessages retrieves from storage current user offline queue. -func (s *Storage) FetchOfflineMessages(username string) ([]xmpp.Message, error) { +func (s *mySQLOffline) FetchOfflineMessages(ctx context.Context, username string) ([]xmpp.Message, error) { q := sq.Select("data"). From("offline_messages"). Where(sq.Eq{"username": username}). OrderBy("created_at") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() buf := s.pool.Get() defer s.pool.Put(buf) @@ -69,24 +81,24 @@ func (s *Storage) FetchOfflineMessages(username string) ([]xmpp.Message, error) if err != nil { return nil, err } - elems := rootEl.Elements().All() + elements := rootEl.Elements().All() - var msgs []xmpp.Message - for _, el := range elems { + messages := make([]xmpp.Message, len(elements)) + for i, el := range elements { fromJID, _ := jid.NewWithString(el.From(), true) toJID, _ := jid.NewWithString(el.To(), true) + msg, err := xmpp.NewMessageFromElement(el, fromJID, toJID) if err != nil { return nil, err } - msgs = append(msgs, *msg) + messages[i] = *msg } - return msgs, nil + return messages, nil } -// DeleteOfflineMessages clears a user offline queue. -func (s *Storage) DeleteOfflineMessages(username string) error { +func (s *mySQLOffline) DeleteOfflineMessages(ctx context.Context, username string) error { q := sq.Delete("offline_messages").Where(sq.Eq{"username": username}) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } diff --git a/storage/mysql/offline_test.go b/storage/mysql/offline_test.go index 5b4d7ea74..156128828 100644 --- a/storage/mysql/offline_test.go +++ b/storage/mysql/offline_test.go @@ -6,9 +6,11 @@ package mysql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" @@ -23,21 +25,21 @@ func TestMySQLStorageInsertOfflineMessages(t *testing.T) { m, _ := xmpp.NewMessageFromElement(message, j, j) messageXML := m.String() - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectExec("INSERT INTO offline_messages (.+)"). WithArgs("ortuman", messageXML). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOfflineMessage(m, "ortuman") + err := s.InsertOfflineMessage(context.Background(), m, "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectExec("INSERT INTO offline_messages (.+)"). WithArgs("ortuman", messageXML). WillReturnError(errMySQLStorage) - err = s.InsertOfflineMessage(m, "ortuman") + err = s.InsertOfflineMessage(context.Background(), m, "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) } @@ -45,30 +47,30 @@ func TestMySQLStorageInsertOfflineMessages(t *testing.T) { func TestMySQLStorageCountOfflineMessages(t *testing.T) { countColums := []string{"count"} - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectQuery("SELECT COUNT(.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(countColums).AddRow(1)) - cnt, _ := s.CountOfflineMessages("ortuman") + cnt, _ := s.CountOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 1, cnt) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT COUNT(.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(countColums)) - cnt, _ = s.CountOfflineMessages("ortuman") + cnt, _ = s.CountOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 0, cnt) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT COUNT(.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - _, err := s.CountOfflineMessages("ortuman") + _, err := s.CountOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } @@ -76,57 +78,65 @@ func TestMySQLStorageCountOfflineMessages(t *testing.T) { func TestMySQLStorageFetchOfflineMessages(t *testing.T) { var offlineMessagesColumns = []string{"data"} - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(offlineMessagesColumns).AddRow("Hi!")) - msgs, _ := s.FetchOfflineMessages("ortuman") + msgs, _ := s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 1, len(msgs)) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(offlineMessagesColumns)) - msgs, _ = s.FetchOfflineMessages("ortuman") + msgs, _ = s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 0, len(msgs)) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(offlineMessagesColumns).AddRow("Hi!")) - _, err := s.FetchOfflineMessages("ortuman") + _, err := s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - _, err = s.FetchOfflineMessages("ortuman") + _, err = s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } func TestMySQLStorageDeleteOfflineMessages(t *testing.T) { - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectExec("DELETE FROM offline_messages (.+)"). WithArgs("ortuman").WillReturnResult(sqlmock.NewResult(0, 1)) - err := s.DeleteOfflineMessages("ortuman") + err := s.DeleteOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectExec("DELETE FROM offline_messages (.+)"). WithArgs("ortuman").WillReturnError(errMySQLStorage) - err = s.DeleteOfflineMessages("ortuman") + err = s.DeleteOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } + +func newOfflineMock() (*mySQLOffline, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLOffline{ + mySQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock +} diff --git a/storage/mysql/presences.go b/storage/mysql/presences.go new file mode 100644 index 000000000..dc528890d --- /dev/null +++ b/storage/mysql/presences.go @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "database/sql" + "encoding/json" + "strings" + + sq "github.com/Masterminds/squirrel" + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/util/pool" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +type mySQLPresences struct { + *mySQLStorage + pool *pool.BufferPool +} + +func newPresences(db *sql.DB) *mySQLPresences { + return &mySQLPresences{ + mySQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +func (s *mySQLPresences) UpsertPresence(ctx context.Context, presence *xmpp.Presence, jid *jid.JID, allocationID string) (inserted bool, err error) { + buf := s.pool.Get() + defer s.pool.Put(buf) + if err := presence.ToXML(buf, true); err != nil { + return false, err + } + var node, ver string + if caps := presence.Capabilities(); caps != nil { + node = caps.Node + ver = caps.Ver + } + rawXML := buf.String() + + q := sq.Insert("presences"). + Columns("username", "domain", "resource", "presence", "node", "ver", "allocation_id", "updated_at", "created_at"). + Values(jid.Node(), jid.Domain(), jid.Resource(), rawXML, node, ver, allocationID, nowExpr, nowExpr). + Suffix("ON DUPLICATE KEY UPDATE presence = ?, node = ?, ver = ?, allocation_id = ?, updated_at = NOW()", rawXML, node, ver, allocationID) + stmRes, err := q.RunWith(s.db).ExecContext(ctx) + if err != nil { + return false, err + } + affectedRows, err := stmRes.RowsAffected() + if err != nil { + return false, err + } + return affectedRows == 1, nil +} + +func (s *mySQLPresences) FetchPresence(ctx context.Context, jid *jid.JID) (*capsmodel.PresenceCaps, error) { + var rawXML, node, ver, featuresJSON string + + q := sq.Select("presence", "c.node", "c.ver", "c.features"). + From("presences AS p, capabilities AS c"). + Where(sq.And{ + sq.Eq{"username": jid.Node()}, + sq.Eq{"domain": jid.Domain()}, + sq.Eq{"resource": jid.Resource()}, + sq.Expr("p.node = c.node"), + sq.Expr("p.ver = c.ver"), + }). + RunWith(s.db) + + err := q.ScanContext(ctx, &rawXML, &node, &ver, &featuresJSON) + switch err { + case nil: + return scanPresenceAndCapabilties(rawXML, node, ver, featuresJSON) + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func (s *mySQLPresences) FetchPresencesMatchingJID(ctx context.Context, jid *jid.JID) ([]capsmodel.PresenceCaps, error) { + var preds sq.And + if len(jid.Node()) > 0 { + preds = append(preds, sq.Eq{"username": jid.Node()}) + } + if len(jid.Domain()) > 0 { + preds = append(preds, sq.Eq{"domain": jid.Domain()}) + } + if len(jid.Resource()) > 0 { + preds = append(preds, sq.Eq{"resource": jid.Resource()}) + } + preds = append(preds, sq.Expr("p.node = c.node")) + preds = append(preds, sq.Expr("p.ver = c.ver")) + + q := sq.Select("presence", "c.node", "c.ver", "c.features"). + From("presences AS p, capabilities AS c"). + Where(preds). + RunWith(s.db) + + rows, err := q.QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var res []capsmodel.PresenceCaps + for rows.Next() { + var rawXML, node, ver, featuresJSON string + + if err := rows.Scan(&rawXML, &node, &ver, &featuresJSON); err != nil { + return nil, err + } + presenceCaps, err := scanPresenceAndCapabilties(rawXML, node, ver, featuresJSON) + if err != nil { + return nil, err + } + res = append(res, *presenceCaps) + } + return res, nil +} + +func (s *mySQLPresences) DeletePresence(ctx context.Context, jid *jid.JID) error { + _, err := sq.Delete("presences"). + Where(sq.And{ + sq.Eq{"username": jid.Node()}, + sq.Eq{"domain": jid.Domain()}, + sq.Eq{"resource": jid.Resource()}, + }). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLPresences) DeleteAllocationPresences(ctx context.Context, allocationID string) error { + _, err := sq.Delete("presences"). + Where(sq.Eq{"allocation_id": allocationID}). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLPresences) ClearPresences(ctx context.Context) error { + _, err := sq.Delete("presences").RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLPresences) UpsertCapabilities(ctx context.Context, caps *capsmodel.Capabilities) error { + b, err := json.Marshal(caps.Features) + if err != nil { + return err + } + _, err = sq.Insert("capabilities"). + Columns("node", "ver", "features", "updated_at", "created_at"). + Values(caps.Node, caps.Ver, b, nowExpr, nowExpr). + Suffix("ON DUPLICATE KEY UPDATE features = ?, updated_at = NOW()", b). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLPresences) FetchCapabilities(ctx context.Context, node, ver string) (*capsmodel.Capabilities, error) { + var b string + err := sq.Select("features").From("capabilities"). + Where(sq.And{sq.Eq{"node": node}, sq.Eq{"ver": ver}}). + RunWith(s.db).QueryRowContext(ctx).Scan(&b) + switch err { + case nil: + var caps capsmodel.Capabilities + if err := json.NewDecoder(strings.NewReader(b)).Decode(&caps.Features); err != nil { + return nil, err + } + return &caps, nil + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func scanPresenceAndCapabilties(rawXML, node, ver, featuresJSON string) (*capsmodel.PresenceCaps, error) { + parser := xmpp.NewParser(strings.NewReader(rawXML), xmpp.DefaultMode, 0) + elem, err := parser.ParseElement() + if err != nil { + return nil, err + } + fromJID, _ := jid.NewWithString(elem.From(), true) + toJID, _ := jid.NewWithString(elem.To(), true) + + presence, err := xmpp.NewPresenceFromElement(elem, fromJID, toJID) + if err != nil { + return nil, err + } + var res capsmodel.PresenceCaps + + res.Presence = presence + if len(featuresJSON) > 0 { + res.Caps = &capsmodel.Capabilities{ + Node: node, + Ver: ver, + } + + if err := json.NewDecoder(strings.NewReader(featuresJSON)).Decode(&res.Caps.Features); err != nil { + return nil, err + } + } + return &res, nil +} diff --git a/storage/mysql/presences_test.go b/storage/mysql/presences_test.go new file mode 100644 index 000000000..9d36719d7 --- /dev/null +++ b/storage/mysql/presences_test.go @@ -0,0 +1,181 @@ +package mysql + +import ( + "context" + "encoding/json" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/util/pool" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestMySQLPresences_UpsertPresence(t *testing.T) { + + s, mock := newPresencesMock() + mock.ExpectExec("INSERT INTO presences (.+) VALUES (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs("ortuman", "jackal.im", "yard", ``, "", "", "alloc-1234", ``, "", "", "alloc-1234"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + inserted, err := s.UpsertPresence(context.Background(), xmpp.NewPresence(j, j.ToBareJID(), xmpp.AvailableType), j, "alloc-1234") + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.True(t, inserted) +} + +func TestMySQLPresences_FetchPresence(t *testing.T) { + var columns = []string{"presence", "c.node", "c.ver", "c.features"} + + s, mock := newPresencesMock() + mock.ExpectQuery("SELECT presence, c.node, c.ver, c.features FROM presences AS p, capabilities AS c WHERE \\(username = \\? AND domain = \\? AND resource = \\? AND p.node = c.node AND p.ver = c.ver\\)"). + WithArgs("ortuman", "jackal.im", "yard"). + WillReturnRows(sqlmock.NewRows(columns). + AddRow("", "http://jackal.im", "v1234", `["urn:xmpp:ping"]`)) + + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + presenceCaps, err := s.FetchPresence(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.NotNil(t, presenceCaps) + + require.Equal(t, "http://jackal.im", presenceCaps.Caps.Node) + require.Equal(t, "v1234", presenceCaps.Caps.Ver) + require.Len(t, presenceCaps.Caps.Features, 1) + require.Equal(t, "urn:xmpp:ping", presenceCaps.Caps.Features[0]) +} + +func TestMySQLPresences_FetchPresencesMatchingJID(t *testing.T) { + var columns = []string{"presence", "c.node", "c.ver", "c.features"} + + s, mock := newPresencesMock() + mock.ExpectQuery("SELECT presence, c.node, c.ver, c.features FROM presences AS p, capabilities AS c WHERE \\(username = \\? AND domain = \\? AND resource = \\? AND p.node = c.node AND p.ver = c.ver\\)"). + WithArgs("ortuman", "jackal.im", "yard"). + WillReturnRows(sqlmock.NewRows(columns). + AddRow("", "http://jackal.im", "v1234", `["urn:xmpp:ping"]`). + AddRow("", "http://jackal.im", "v1234", `["urn:xmpp:ping"]`), + ) + + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + presenceCaps, err := s.FetchPresencesMatchingJID(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.NotNil(t, presenceCaps) + + require.Equal(t, "http://jackal.im", presenceCaps[0].Caps.Node) + require.Equal(t, "v1234", presenceCaps[0].Caps.Ver) + require.Len(t, presenceCaps[0].Caps.Features, 1) + require.Equal(t, "urn:xmpp:ping", presenceCaps[0].Caps.Features[0]) +} + +func TestMySQLPresences_DeletePresence(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + + s, mock := newPresencesMock() + mock.ExpectExec("DELETE FROM presences WHERE \\(username = \\? AND domain = \\? AND resource = \\?\\)"). + WithArgs(j.Node(), j.Domain(), j.Resource()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := s.DeletePresence(context.Background(), j) + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestMySQLPresences_DeleteAllocationPresence(t *testing.T) { + s, mock := newPresencesMock() + mock.ExpectExec("DELETE FROM presences WHERE allocation_id = ?"). + WithArgs("alloc-1234"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := s.DeleteAllocationPresences(context.Background(), "alloc-1234") + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestMySQLPresences_ClearPresences(t *testing.T) { + s, mock := newPresencesMock() + mock.ExpectExec("DELETE FROM presences"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := s.ClearPresences(context.Background()) + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestMySQLPresences_UpsertCapabilities(t *testing.T) { + features := []string{"jabber:iq:last"} + + b, _ := json.Marshal(&features) + + s, mock := newPresencesMock() + mock.ExpectExec("INSERT INTO capabilities (.+) VALUES (.+) ON DUPLICATE KEY UPDATE features = \\?, updated_at = NOW\\(\\)"). + WithArgs("n1", "1234A", b, b). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := s.UpsertCapabilities(context.Background(), &capsmodel.Capabilities{Node: "n1", Ver: "1234A", Features: features}) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPresencesMock() + mock.ExpectExec("INSERT INTO capabilities (.+) VALUES (.+) ON DUPLICATE KEY UPDATE features = \\?, updated_at = NOW\\(\\)"). + WithArgs("n1", "1234A", b, b). + WillReturnError(errMySQLStorage) + + err = s.UpsertCapabilities(context.Background(), &capsmodel.Capabilities{Node: "n1", Ver: "1234A", Features: features}) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLPresences_FetchCapabilities(t *testing.T) { + s, mock := newPresencesMock() + rows := sqlmock.NewRows([]string{"features"}) + rows.AddRow(`["jabber:iq:last"]`) + + mock.ExpectQuery("SELECT features FROM capabilities WHERE \\(node = . AND ver = .\\)"). + WithArgs("n1", "1234A"). + WillReturnRows(rows) + + caps, err := s.FetchCapabilities(context.Background(), "n1", "1234A") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 1, len(caps.Features)) + require.Equal(t, "jabber:iq:last", caps.Features[0]) + + // error case + s, mock = newPresencesMock() + mock.ExpectQuery("SELECT features FROM capabilities WHERE \\(node = . AND ver = .\\)"). + WithArgs("n1", "1234A"). + WillReturnError(errMySQLStorage) + + caps, err = s.FetchCapabilities(context.Background(), "n1", "1234A") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Nil(t, caps) +} + +func newPresencesMock() (*mySQLPresences, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLPresences{ + mySQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock +} diff --git a/storage/mysql/private.go b/storage/mysql/private.go index f801b4e13..a08bf527e 100644 --- a/storage/mysql/private.go +++ b/storage/mysql/private.go @@ -6,19 +6,33 @@ package mysql import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" ) -// InsertOrUpdatePrivateXML inserts a new private element into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error { +type mySQLPrivate struct { + *mySQLStorage + pool *pool.BufferPool +} + +func newPrivate(db *sql.DB) *mySQLPrivate { + return &mySQLPrivate{ + mySQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +func (s *mySQLPrivate) UpsertPrivateXML(ctx context.Context, privateXML []xmpp.XElement, namespace string, username string) error { buf := s.pool.Get() defer s.pool.Put(buf) for _, elem := range privateXML { - elem.ToXML(buf, true) + if err := elem.ToXML(buf, true); err != nil { + return err + } } rawXML := buf.String() @@ -27,18 +41,17 @@ func (s *Storage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace Values(username, namespace, rawXML, nowExpr, nowExpr). Suffix("ON DUPLICATE KEY UPDATE data = ?, updated_at = NOW()", rawXML) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -// FetchPrivateXML retrieves from storage a private element. -func (s *Storage) FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) { +func (s *mySQLPrivate) FetchPrivateXML(ctx context.Context, namespace string, username string) ([]xmpp.XElement, error) { q := sq.Select("data"). From("private_storage"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"namespace": namespace}}) var privateXML string - err := q.RunWith(s.db).QueryRow().Scan(&privateXML) + err := q.RunWith(s.db).QueryRowContext(ctx).Scan(&privateXML) switch err { case nil: buf := s.pool.Get() diff --git a/storage/mysql/private_test.go b/storage/mysql/private_test.go index c918572c8..26fec9134 100644 --- a/storage/mysql/private_test.go +++ b/storage/mysql/private_test.go @@ -6,9 +6,11 @@ package mysql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" ) @@ -17,21 +19,21 @@ func TestMySQLStorageInsertPrivateXML(t *testing.T) { private := xmpp.NewElementNamespace("exodus", "exodus:ns") rawXML := private.String() - s, mock := NewMock() + s, mock := newPrivateMock() mock.ExpectExec("INSERT INTO private_storage (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("ortuman", "exodus:ns", rawXML, rawXML). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") + err := s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectExec("INSERT INTO private_storage (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("ortuman", "exodus:ns", rawXML, rawXML). WillReturnError(errMySQLStorage) - err = s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") + err = s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } @@ -39,53 +41,61 @@ func TestMySQLStorageInsertPrivateXML(t *testing.T) { func TestMySQLStorageFetchPrivateXML(t *testing.T) { var privateColumns = []string{"data"} - s, mock := NewMock() + s, mock := newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns).AddRow("")) - elems, err := s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err := s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 1, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns).AddRow("")) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) require.Equal(t, 0, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns).AddRow("")) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 0, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns)) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 0, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnError(errMySQLStorage) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) require.Equal(t, 0, len(elems)) } + +func newPrivateMock() (*mySQLPrivate, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLPrivate{ + mySQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock +} diff --git a/storage/mysql/pubsub.go b/storage/mysql/pubsub.go new file mode 100644 index 000000000..332cd1125 --- /dev/null +++ b/storage/mysql/pubsub.go @@ -0,0 +1,524 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "database/sql" + "strings" + + sq "github.com/Masterminds/squirrel" + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + "github.com/ortuman/jackal/xmpp" +) + +type mySQLPubSub struct { + *mySQLStorage +} + +func newPubSub(db *sql.DB) *mySQLPubSub { + return &mySQLPubSub{ + mySQLStorage: newStorage(db), + } +} + +func (s *mySQLPubSub) FetchHosts(ctx context.Context) ([]string, error) { + rows, err := sq.Select("DISTINCT(host)"). + From("pubsub_nodes"). + RunWith(s.db). + QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var hosts []string + for rows.Next() { + var host string + if err := rows.Scan(&host); err != nil { + return nil, err + } + hosts = append(hosts, host) + } + return hosts, nil +} + +func (s *mySQLPubSub) UpsertNode(ctx context.Context, node *pubsubmodel.Node) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + + // if not existing, insert new node + _, err := sq.Insert("pubsub_nodes"). + Columns("host", "name", "updated_at", "created_at"). + Suffix("ON DUPLICATE KEY UPDATE updated_at = NOW()"). + Values(node.Host, node.Name, nowExpr, nowExpr). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // fetch node identifier + var nodeIdentifier string + + err = sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": node.Host}, sq.Eq{"name": node.Name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + if err != nil { + return err + } + // delete previous node options + _, err = sq.Delete("pubsub_node_options"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // insert new option set + optionSetMap, err := node.Options.Map() + if err != nil { + return err + } + for name, value := range optionSetMap { + _, err = sq.Insert("pubsub_node_options"). + Columns("node_id", "name", "value", "updated_at", "created_at"). + Values(nodeIdentifier, name, value, nowExpr, nowExpr). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + return nil + }) +} + +func (s *mySQLPubSub) FetchNode(ctx context.Context, host, name string) (*pubsubmodel.Node, error) { + opts, err := s.fetchPubSubNodeOptions(ctx, host, name) + if err != nil { + return nil, err + } + if opts == nil { + return nil, nil // not found + } + return &pubsubmodel.Node{ + Host: host, + Name: name, + Options: *opts, + }, nil +} + +func (s *mySQLPubSub) FetchNodes(ctx context.Context, host string) ([]pubsubmodel.Node, error) { + rows, err := sq.Select("name"). + From("pubsub_nodes"). + Where(sq.Eq{"host": host}). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var nodes []pubsubmodel.Node + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + var node = pubsubmodel.Node{Host: host, Name: name} + opts, err := s.fetchPubSubNodeOptions(ctx, host, name) + if err != nil { + return nil, err + } + if opts != nil { + node.Options = *opts + } + nodes = append(nodes, node) + } + return nodes, nil +} + +func (s *mySQLPubSub) FetchSubscribedNodes(ctx context.Context, jid string) ([]pubsubmodel.Node, error) { + rows, err := sq.Select("host", "name"). + From("pubsub_nodes"). + Where(sq.Expr("id IN (SELECT DISTINCT(node_id) FROM pubsub_subscriptions WHERE jid = ? AND subscription = ?)", jid, pubsubmodel.Subscribed)). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var nodes []pubsubmodel.Node + for rows.Next() { + var host, name string + if err := rows.Scan(&host, &name); err != nil { + return nil, err + } + var node = pubsubmodel.Node{Host: host, Name: name} + opts, err := s.fetchPubSubNodeOptions(ctx, host, name) + if err != nil { + return nil, err + } + if opts != nil { + node.Options = *opts + } + nodes = append(nodes, node) + } + return nodes, nil +} + +func (s *mySQLPubSub) DeleteNode(ctx context.Context, host, name string) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + // delete node + _, err = sq.Delete("pubsub_nodes"). + Where(sq.Eq{"id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete options + _, err = sq.Delete("pubsub_node_options"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete items + _, err = sq.Delete("pubsub_items"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete affiliations + _, err = sq.Delete("pubsub_affiliations"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete subscriptions + _, err = sq.Delete("pubsub_subscriptions"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *mySQLPubSub) UpsertNodeItem(ctx context.Context, item *pubsubmodel.Item, host, name string, maxNodeItems int) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + + // upsert new item + rawPayload := item.Payload.String() + + _, err = sq.Insert("pubsub_items"). + Columns("node_id", "item_id", "payload", "publisher", "updated_at", "created_at"). + Values(nodeIdentifier, item.ID, rawPayload, item.Publisher, nowExpr, nowExpr). + Suffix("ON DUPLICATE KEY UPDATE payload = ?, publisher = ?, updated_at = NOW()", rawPayload, item.Publisher). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // fetch valid identifiers + rows, err := sq.Select("item_id"). + From("pubsub_items"). + Where(sq.Eq{"node_id": nodeIdentifier}). + OrderBy("created_at DESC"). + Limit(uint64(maxNodeItems)).RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var validIdentifiers []string + for rows.Next() { + var identifier string + if err := rows.Scan(&identifier); err != nil { + return err + } + validIdentifiers = append(validIdentifiers, identifier) + } + // delete older items + _, err = sq.Delete("pubsub_items"). + Where(sq.And{sq.Eq{"node_id": nodeIdentifier}, sq.NotEq{"item_id": validIdentifiers}}). + RunWith(tx). + ExecContext(ctx) + return err + }) +} + +func (s *mySQLPubSub) FetchNodeItems(ctx context.Context, host, name string) ([]pubsubmodel.Item, error) { + rows, err := sq.Select("item_id", "publisher", "payload"). + From("pubsub_items"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", host, name). + OrderBy("created_at"). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeItems(rows) +} + +func (s *mySQLPubSub) FetchNodeItemsWithIDs(ctx context.Context, host, name string, identifiers []string) ([]pubsubmodel.Item, error) { + rows, err := sq.Select("item_id", "publisher", "payload"). + From("pubsub_items"). + Where(sq.And{sq.Expr("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", host, name), sq.Eq{"id": identifiers}}). + OrderBy("created_at"). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeItems(rows) +} + +func (s *mySQLPubSub) FetchNodeLastItem(ctx context.Context, host, name string) (*pubsubmodel.Item, error) { + row := sq.Select("item_id", "publisher", "payload"). + From("pubsub_items"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", host, name). + OrderBy("created_at DESC"). + Limit(1). + RunWith(s.db).QueryRowContext(ctx) + + item, err := scanPubSubNodeItem(row) + switch err { + case nil: + return item, nil + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func (s *mySQLPubSub) UpsertNodeAffiliation(ctx context.Context, affiliation *pubsubmodel.Affiliation, host, name string) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + + // insert affiliation + _, err = sq.Insert("pubsub_affiliations"). + Columns("node_id", "jid", "affiliation", "updated_at", "created_at"). + Values(nodeIdentifier, affiliation.JID, affiliation.Affiliation, nowExpr, nowExpr). + Suffix("ON DUPLICATE KEY UPDATE affiliation = ?, updated_at = NOW()", affiliation.Affiliation). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *mySQLPubSub) FetchNodeAffiliation(ctx context.Context, host, name, jid string) (*pubsubmodel.Affiliation, error) { + var aff pubsubmodel.Affiliation + + row := sq.Select("jid", "affiliation"). + From("pubsub_affiliations"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?) AND jid = ?", host, name, jid). + RunWith(s.db).QueryRowContext(ctx) + err := row.Scan(&aff.JID, &aff.Affiliation) + switch err { + case nil: + return &aff, nil + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func (s *mySQLPubSub) FetchNodeAffiliations(ctx context.Context, host, name string) ([]pubsubmodel.Affiliation, error) { + rows, err := sq.Select("jid", "affiliation"). + From("pubsub_affiliations"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", host, name). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeAffiliations(rows) +} + +func (s *mySQLPubSub) DeleteNodeAffiliation(ctx context.Context, jid, host, name string) error { + _, err := sq.Delete("pubsub_affiliations"). + Where("jid = ? AND node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", jid, host, name). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLPubSub) UpsertNodeSubscription(ctx context.Context, subscription *pubsubmodel.Subscription, host, name string) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + + // upsert subscription + _, err = sq.Insert("pubsub_subscriptions"). + Columns("node_id", "subid", "jid", "subscription", "updated_at", "created_at"). + Values(nodeIdentifier, subscription.SubID, subscription.JID, subscription.Subscription, nowExpr, nowExpr). + Suffix("ON DUPLICATE KEY UPDATE subid = ?, subscription = ?, updated_at = NOW()", subscription.SubID, subscription.Subscription). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *mySQLPubSub) FetchNodeSubscriptions(ctx context.Context, host, name string) ([]pubsubmodel.Subscription, error) { + rows, err := sq.Select("subid", "jid", "subscription"). + From("pubsub_subscriptions"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", host, name). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeSubscriptions(rows) +} + +func (s *mySQLPubSub) DeleteNodeSubscription(ctx context.Context, jid, host, name string) error { + _, err := sq.Delete("pubsub_subscriptions"). + Where("jid = ? AND node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", jid, host, name). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *mySQLPubSub) fetchPubSubNodeOptions(ctx context.Context, host, name string) (*pubsubmodel.Options, error) { + rows, err := sq.Select("name", "value"). + From("pubsub_node_options"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = ? AND name = ?)", host, name). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var optMap = make(map[string]string) + for rows.Next() { + var opt, value string + if err := rows.Scan(&opt, &value); err != nil { + return nil, err + } + optMap[opt] = value + } + if len(optMap) == 0 { + return nil, nil // node does not exist + } + opts, err := pubsubmodel.NewOptionsFromMap(optMap) + if err != nil { + return nil, err + } + return opts, nil +} + +func scanPubSubNodeAffiliations(scanner rowsScanner) ([]pubsubmodel.Affiliation, error) { + var affiliations []pubsubmodel.Affiliation + + for scanner.Next() { + var affiliation pubsubmodel.Affiliation + if err := scanner.Scan(&affiliation.JID, &affiliation.Affiliation); err != nil { + return nil, err + } + affiliations = append(affiliations, affiliation) + } + return affiliations, nil +} + +func scanPubSubNodeSubscriptions(scanner rowsScanner) ([]pubsubmodel.Subscription, error) { + var subscriptions []pubsubmodel.Subscription + + for scanner.Next() { + var subscription pubsubmodel.Subscription + if err := scanner.Scan(&subscription.SubID, &subscription.JID, &subscription.Subscription); err != nil { + return nil, err + } + subscriptions = append(subscriptions, subscription) + } + return subscriptions, nil +} + +func scanPubSubNodeItems(scanner rowsScanner) ([]pubsubmodel.Item, error) { + var items []pubsubmodel.Item + + for scanner.Next() { + item, err := scanPubSubNodeItem(scanner) + if err != nil { + return nil, err + } + items = append(items, *item) + } + return items, nil +} + +func scanPubSubNodeItem(scanner rowScanner) (*pubsubmodel.Item, error) { + var payload string + var item pubsubmodel.Item + var err error + + if err = scanner.Scan(&item.ID, &item.Publisher, &payload); err != nil { + return nil, err + } + parser := xmpp.NewParser(strings.NewReader(payload), xmpp.DefaultMode, 0) + item.Payload, err = parser.ParseElement() + if err != nil { + return nil, err + } + return &item, nil +} diff --git a/storage/mysql/pubsub_test.go b/storage/mysql/pubsub_test.go new file mode 100644 index 000000000..7e36097d3 --- /dev/null +++ b/storage/mysql/pubsub_test.go @@ -0,0 +1,478 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/google/uuid" + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + "github.com/ortuman/jackal/xmpp" + "github.com/stretchr/testify/require" +) + +func TestMySQLFetchPubSubHosts(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"host"}) + rows.AddRow("ortuman@jackal.im") + rows.AddRow("noelia@jackal.im") + + mock.ExpectQuery("SELECT DISTINCT\\(host\\) FROM pubsub_nodes"). + WillReturnRows(rows) + + hosts, err := s.FetchHosts(context.Background()) + require.Nil(t, err) + require.NotNil(t, hosts) + require.Equal(t, "ortuman@jackal.im", hosts[0]) + require.Equal(t, "noelia@jackal.im", hosts[1]) + + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT DISTINCT\\(host\\) FROM pubsub_nodes"). + WillReturnError(errMySQLStorage) + + hosts, err = s.FetchHosts(context.Background()) + require.Nil(t, hosts) + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLUpsertPubSubNode(t *testing.T) { + s, mock := newPubSubMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO pubsub_nodes (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs("host", "name"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("host", "name"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("DELETE FROM pubsub_node_options WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + opts := pubsubmodel.Options{} + + optMap, _ := opts.Map() + for i := 0; i < len(optMap); i++ { + mock.ExpectExec("INSERT INTO pubsub_node_options (.+)"). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 1)) + } + mock.ExpectCommit() + + node := pubsubmodel.Node{Host: "host", Name: "name", Options: opts} + err := s.UpsertNode(context.Background(), &node) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestMySQLFetchPubSubNodes(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"name"}) + rows.AddRow("princely_musings_1") + rows.AddRow("princely_musings_2") + + mock.ExpectQuery("SELECT name FROM pubsub_nodes WHERE host = (.+)"). + WithArgs("ortuman@jackal.im"). + WillReturnRows(rows) + + var cols = []string{"name", "value"} + + rows = sqlmock.NewRows(cols) + rows.AddRow("pubsub#access_model", "presence") + rows.AddRow("pubsub#publish_model", "publishers") + rows.AddRow("pubsub#send_last_published_item", "on_sub_and_presence") + + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_1"). + WillReturnRows(rows) + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_2"). + WillReturnRows(rows) + + nodes, err := s.FetchNodes(context.Background(), "ortuman@jackal.im") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.NotNil(t, nodes) + require.Len(t, nodes, 2) + require.Equal(t, "princely_musings_1", nodes[0].Name) + require.Equal(t, "princely_musings_2", nodes[1].Name) +} + +func TestMySQLFetchPubSubSubscribedNodes(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"host", "name"}) + rows.AddRow("ortuman@jackal.im", "princely_musings_1") + rows.AddRow("ortuman@jackal.im", "princely_musings_2") + + mock.ExpectQuery("SELECT host, name FROM pubsub_nodes WHERE id IN (.+)"). + WithArgs("ortuman@jackal.im", pubsubmodel.Subscribed). + WillReturnRows(rows) + + var cols = []string{"name", "value"} + + rows = sqlmock.NewRows(cols) + rows.AddRow("pubsub#access_model", "presence") + rows.AddRow("pubsub#publish_model", "publishers") + rows.AddRow("pubsub#send_last_published_item", "on_sub_and_presence") + + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_1"). + WillReturnRows(rows) + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_2"). + WillReturnRows(rows) + + nodes, err := s.FetchSubscribedNodes(context.Background(), "ortuman@jackal.im") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.NotNil(t, nodes) + require.Len(t, nodes, 2) + require.Equal(t, "princely_musings_1", nodes[0].Name) + require.Equal(t, "princely_musings_2", nodes[1].Name) +} + +func TestMySQLFetchPubSubNode(t *testing.T) { + var cols = []string{"name", "value"} + + s, mock := newPubSubMock() + rows := sqlmock.NewRows(cols) + rows.AddRow("pubsub#access_model", "presence") + rows.AddRow("pubsub#publish_model", "publishers") + rows.AddRow("pubsub#send_last_published_item", "on_sub_and_presence") + + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + node, err := s.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.NotNil(t, node) + require.Equal(t, node.Options.AccessModel, pubsubmodel.Presence) + require.Equal(t, node.Options.SendLastPublishedItem, pubsubmodel.OnSubAndPresence) + + // error case + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errMySQLStorage) + + _, err = s.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLDeletePubSubNode(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("DELETE FROM pubsub_nodes WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_node_options WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_items WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_affiliations WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.DeleteNode(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestMySQLUpsertPubSubNodeItem(t *testing.T) { + payload := xmpp.NewIQType(uuid.New().String(), xmpp.GetType) + + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("INSERT INTO pubsub_items (.+) ON DUPLICATE KEY UPDATE payload = (.+), publisher = (.+), updated_at = NOW()"). + WithArgs("1", "abc1234", payload.String(), "ortuman@jackal.im", payload.String(), "ortuman@jackal.im"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectQuery("SELECT item_id FROM pubsub_items WHERE node_id = \\? ORDER BY created_at DESC LIMIT 1"). + WithArgs("1"). + WillReturnRows(sqlmock.NewRows([]string{"item_id"}).AddRow("1").AddRow("2")) + + mock.ExpectExec("DELETE FROM pubsub_items WHERE \\(node_id = \\? AND item_id NOT IN \\(.+\\)\\)"). + WithArgs("1", "1", "2"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.UpsertNodeItem(context.Background(), &pubsubmodel.Item{ + ID: "abc1234", + Publisher: "ortuman@jackal.im", + Payload: payload, + }, "ortuman@jackal.im", "princely_musings", 1) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestMySQLFetchPubSubNodeItems(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"item_id", "publisher", "payload"}) + rows.AddRow("1234", "ortuman@jackal.im", "") + rows.AddRow("5678", "noelia@jackal.im", "") + + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE node_id = (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + items, err := s.FetchNodeItems(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(items)) + require.Equal(t, "1234", items[0].ID) + require.Equal(t, "5678", items[1].ID) + + // error case + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE node_id = (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errMySQLStorage) + + _, err = s.FetchNodeItems(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLFetchPubSubNodeItemsWithID(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"item_id", "publisher", "payload"}) + rows.AddRow("1234", "ortuman@jackal.im", "") + rows.AddRow("5678", "noelia@jackal.im", "") + + identifiers := []string{"1234", "5678"} + + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE (.+ IN (.+)) ORDER BY created_at"). + WithArgs("ortuman@jackal.im", "princely_musings", "1234", "5678"). + WillReturnRows(rows) + + items, err := s.FetchNodeItemsWithIDs(context.Background(), "ortuman@jackal.im", "princely_musings", identifiers) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(items)) + require.Equal(t, "1234", items[0].ID) + require.Equal(t, "5678", items[1].ID) + + // error case + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE (.+ IN (.+)) ORDER BY created_at"). + WithArgs("ortuman@jackal.im", "princely_musings", "1234", "5678"). + WillReturnError(errMySQLStorage) + + _, err = s.FetchNodeItemsWithIDs(context.Background(), "ortuman@jackal.im", "princely_musings", identifiers) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLUpsertPubSubNodeAffiliation(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("INSERT INTO pubsub_affiliations (.+) VALUES (.+) ON DUPLICATE KEY UPDATE affiliation = (.+), updated_at = (.+)"). + WithArgs("1", "ortuman@jackal.im", "owner", "owner"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: "owner", + }, "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestMySQLFetchPubSubNodeAffiliations(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"jid", "affiliation"}) + rows.AddRow("ortuman@jackal.im", "owner") + rows.AddRow("noelia@jackal.im", "publisher") + + mock.ExpectQuery("SELECT jid, affiliation FROM pubsub_affiliations WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + affiliations, err := s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(affiliations)) + + // error case + mock.ExpectQuery("SELECT jid, affiliation FROM pubsub_affiliations WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errMySQLStorage) + + affiliations, err = s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestPgSQLDeletePubSubNodeAffiliation(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectExec("DELETE FROM pubsub_affiliations WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := s.DeleteNodeAffiliation(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPubSubMock() + mock.ExpectExec("DELETE FROM pubsub_affiliations WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnError(errMySQLStorage) + + err = s.DeleteNodeAffiliation(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLUpsertPubSubNodeSubscription(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("INSERT INTO pubsub_subscriptions (.+) VALUES (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs("1", "1234", "ortuman@jackal.im", "subscribed", "1234", "subscribed"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + SubID: "1234", + JID: "ortuman@jackal.im", + Subscription: "subscribed", + }, "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestMySQLFetchPubSubNodeSubscriptions(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"subid", "jid", "subscription"}) + rows.AddRow("1234", "ortuman@jackal.im", "subscribed") + rows.AddRow("5678", "noelia@jackal.im", "unsubscribed") + + mock.ExpectQuery("SELECT subid, jid, subscription FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + subscriptions, err := s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(subscriptions)) + + // error case + mock.ExpectQuery("SELECT subid, jid, subscription FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errMySQLStorage) + + subscriptions, err = s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func TestMySQLDeletePubSubNodeSubscription(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectExec("DELETE FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := s.DeleteNodeSubscription(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPubSubMock() + mock.ExpectExec("DELETE FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnError(errMySQLStorage) + + err = s.DeleteNodeSubscription(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func newPubSubMock() (*mySQLPubSub, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLPubSub{ + mySQLStorage: s, + }, sqlMock +} diff --git a/storage/mysql/room.go b/storage/mysql/room.go new file mode 100644 index 000000000..5ddfd7965 --- /dev/null +++ b/storage/mysql/room.go @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "database/sql" + + sq "github.com/Masterminds/squirrel" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +type mySQLRoom struct { + *mySQLStorage +} + +func newRoom(db *sql.DB) *mySQLRoom { + return &mySQLRoom{ + mySQLStorage: newStorage(db), + } +} + +func (r *mySQLRoom) UpsertRoom(ctx context.Context, room *mucmodel.Room) error { + return r.inTransaction(ctx, func(tx *sql.Tx) error { + // rooms table + columns := []string{"room_jid", "name", "description", "subject", "language", "locked", + "occupants_online"} + values := []interface{}{room.RoomJID.String(), room.Name, room.Desc, room.Subject, + room.Language, room.Locked, room.GetOccupantsOnlineCount()} + q := sq.Insert("rooms"). + Columns(columns...). + Values(values...). + Suffix("ON DUPLICATE KEY UPDATE name = ?, description = ?, subject = ?, language = ?,"+ + " locked = ?, occupants_online = ?", room.Name, room.Desc, room.Subject, + room.Language, room.Locked, room.GetOccupantsOnlineCount()) + _, err := q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // rooms_config table + rc := room.Config + columns = []string{"room_jid", "public", "persistent", "pwd_protected", "password", "open", + "moderated", "allow_invites", "max_occupants", "allow_subj_change", "non_anonymous", + "can_send_pm", "can_get_member_list"} + values = []interface{}{room.RoomJID.String(), rc.Public, rc.Persistent, rc.PwdProtected, + rc.Password, rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, + rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList()} + q = sq.Insert("rooms_config"). + Columns(columns...). + Values(values...). + Suffix("ON DUPLICATE KEY UPDATE public = ?, persistent = ?, pwd_protected = ?, "+ + "password = ?, open = ?, moderated = ?, allow_invites = ?, max_occupants = ?, "+ + "allow_subj_change = ?, non_anonymous = ?, can_send_pm = ?, can_get_member_list = ?", + rc.Public, rc.Persistent, rc.PwdProtected, rc.Password, rc.Open, rc.Moderated, + rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, rc.NonAnonymous, rc.WhoCanSendPM(), + rc.WhoCanGetMemberList()) + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // rooms_invites table + columns = []string{"room_jid", "user_jid"} + for _, u := range room.GetAllInvitedUsers() { + values = []interface{}{room.RoomJID.String(), u} + q = sq.Insert("rooms_invites"). + Columns(columns...). + Values(values...). + Suffix("ON DUPLICATE KEY UPDATE user_jid = ?", u) + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + + // rooms_users table + columns = []string{"room_jid", "user_jid", "occupant_jid"} + for _, u := range room.GetAllUserJIDs() { + occJID, _ := room.GetOccupantJID(&u) + values = []interface{}{room.RoomJID.String(), u.String(), occJID.String()} + q = sq.Insert("rooms_users"). + Columns(columns...). + Values(values...). + Suffix("ON DUPLICATE KEY UPDATE occupant_jid = ?", occJID.String()) + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + return nil + }) +} + +func (r *mySQLRoom) FetchRoom(ctx context.Context, roomJID *jid.JID) (*mucmodel.Room, error) { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + room, err := fetchRoomData(ctx, tx, roomJID) + switch err { + case nil: + case sql.ErrNoRows: + _ = tx.Commit() + return nil, nil + default: + _ = tx.Rollback() + return nil, err + } + + err = fetchRoomConfig(ctx, tx, room, roomJID) + switch err { + case nil: + case sql.ErrNoRows: + _ = tx.Commit() + return nil, nil + default: + _ = tx.Rollback() + return nil, err + } + + err = fetchRoomUsers(ctx, tx, room, roomJID) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + err = fetchRoomInvites(ctx, tx, room, roomJID) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return room, nil +} + +func fetchRoomData(ctx context.Context, tx *sql.Tx, roomJID *jid.JID) (*mucmodel.Room, + error) { + room := &mucmodel.Room{} + // fetch room data + q := sq.Select("room_jid", "name", "description", "subject", "language", "locked", + "occupants_online"). + From("rooms"). + Where(sq.Eq{"room_jid": roomJID.String()}) + var onlineCnt int + var roomJIDStr string + err := q.RunWith(tx). + QueryRowContext(ctx). + Scan(&roomJIDStr, &room.Name, &room.Desc, &room.Subject, &room.Language, &room.Locked, + &onlineCnt) + switch err { + case nil: + rJID, err := jid.NewWithString(roomJIDStr, false) + if err != nil { + return nil, err + } + room.RoomJID = rJID + room.SetOccupantsOnlineCount(onlineCnt) + default: + return nil, err + } + return room, nil +} + +func fetchRoomConfig(ctx context.Context, tx *sql.Tx, room *mucmodel.Room, + roomJID *jid.JID) error { + rc := &mucmodel.RoomConfig{} + q := sq.Select("room_jid", "public", "persistent", "pwd_protected", "password", "open", + "moderated", "allow_invites", "max_occupants", "allow_subj_change", "non_anonymous", + "can_send_pm", "can_get_member_list"). + From("rooms_config"). + Where(sq.Eq{"room_jid": roomJID.String()}) + var dummy, sendPM, membList string + err := q.RunWith(tx). + QueryRowContext(ctx). + Scan(&dummy, &rc.Public, &rc.Persistent, &rc.PwdProtected, &rc.Password, &rc.Open, + &rc.Moderated, &rc.AllowInvites, &rc.MaxOccCnt, &rc.AllowSubjChange, &rc.NonAnonymous, + &sendPM, &membList) + switch err { + case nil: + err = rc.SetWhoCanSendPM(sendPM) + if err != nil { + return err + } + err = rc.SetWhoCanGetMemberList(membList) + if err != nil { + return err + } + default: + return err + } + room.Config = rc + return nil +} + +func fetchRoomUsers(ctx context.Context, tx *sql.Tx, room *mucmodel.Room, + roomJID *jid.JID) error { + res, err := sq.Select("room_jid", "user_jid", "occupant_jid"). + From("rooms_users"). + Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + for res.Next() { + var dummy, uJIDStr, oJIDStr string + if err := res.Scan(&dummy, &uJIDStr, &oJIDStr); err != nil { + return err + } + uJID, err := jid.NewWithString(uJIDStr, false) + if err != nil { + return err + } + oJID, err := jid.NewWithString(oJIDStr, false) + if err != nil { + return err + } + err = room.MapUserToOccupantJID(uJID, oJID) + if err != nil { + return err + } + } + return nil +} + +func fetchRoomInvites(ctx context.Context, tx *sql.Tx, room *mucmodel.Room, + roomJID *jid.JID) error { + resInv, err := sq.Select("room_jid", "user_jid"). + From("rooms_invites"). + Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + for resInv.Next() { + var dummy, uJIDStr string + if err := resInv.Scan(&dummy, &uJIDStr); err != nil { + return err + } + uJID, err := jid.NewWithString(uJIDStr, false) + if err != nil { + return err + } + err = room.InviteUser(uJID) + if err != nil { + return err + } + } + return nil +} + +func (r *mySQLRoom) DeleteRoom(ctx context.Context, roomJID *jid.JID) error { + return r.inTransaction(ctx, func(tx *sql.Tx) error { + _, err := sq.Delete("rooms").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("rooms_config").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("rooms_users").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("rooms_invites").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + return nil + }) +} + +func (r *mySQLRoom) RoomExists(ctx context.Context, roomJID *jid.JID) (bool, error) { + q := sq.Select("COUNT(*)"). + From("rooms"). + Where(sq.Eq{"room_jid": roomJID.String()}) + + var count int + err := q.RunWith(r.db).QueryRowContext(ctx).Scan(&count) + switch err { + case nil: + return count > 0, nil + default: + return false, err + } +} diff --git a/storage/mysql/room_test.go b/storage/mysql/room_test.go new file mode 100644 index 000000000..453896b16 --- /dev/null +++ b/storage/mysql/room_test.go @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMySQLStorageInsertRoom(t *testing.T) { + room := getTestRoom() + s, mock := newRoomMock() + rc := room.Config + userJID := room.GetAllUserJIDs()[0] + occJID, _ := room.GetOccupantJID(&userJID) + invitedUser := room.GetAllInvitedUsers()[0] + + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO rooms (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(room.RoomJID.String(), room.Name, room.Desc, room.Subject, room.Language, + room.Locked, room.GetOccupantsOnlineCount(), room.Name, room.Desc, room.Subject, + room.Language, room.Locked, room.GetOccupantsOnlineCount()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO rooms_config (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(room.RoomJID.String(), rc.Public, rc.Persistent, rc.PwdProtected, + rc.Password, rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, + rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList(), rc.Public, rc.Persistent, + rc.PwdProtected, rc.Password, rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, + rc.AllowSubjChange, rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO rooms_invites (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(room.RoomJID.String(), invitedUser, invitedUser). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO rooms_users (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(room.RoomJID.String(), userJID.String(), occJID.String(), occJID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := s.UpsertRoom(context.Background(), room) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO rooms (.+) ON DUPLICATE KEY UPDATE (.+)"). + WithArgs(room.RoomJID.String(), room.Name, room.Desc, room.Subject, room.Language, + room.Locked, room.GetOccupantsOnlineCount(), room.Name, room.Desc, room.Subject, + room.Language, room.Locked, room.GetOccupantsOnlineCount()). + WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.UpsertRoom(context.Background(), room) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, err, errMocked) +} + +func TestMySQLStorageFetchRoom(t *testing.T) { + room := getTestRoom() + rc := room.Config + s, mock := newRoomMock() + roomColumns := []string{"room_jid", "name", "description", "subject", "language", "locked", + "occupants_online"} + rcColumns := []string{"room_jid", "public", "persistent", "pwd_protected", "password", "open", + "moderated", "allow_invites", "max_occupants", "allow_subj_change", "non_anonymous", + "can_send_pm", "can_get_member_list"} + usersColumns := []string{"room_jid", "user_jid", "occupant_jid"} + invitesColumns := []string{"room_jid", "user_jid"} + + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(roomColumns)) + mock.ExpectCommit() + + r, _ := s.FetchRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, r) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(roomColumns). + AddRow(room.RoomJID.String(), room.Name, room.Desc, room.Subject, room.Language, + room.Locked, room.GetOccupantsOnlineCount())) + mock.ExpectQuery("SELECT (.+) FROM rooms_config (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(rcColumns). + AddRow(room.RoomJID.String(), rc.Public, rc.Persistent, rc.PwdProtected, rc.Password, + rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, + rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList())) + mock.ExpectQuery("SELECT (.+) FROM rooms_users (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(usersColumns). + AddRow(room.RoomJID.String(), room.GetAllUserJIDs()[0].String(), + room.GetAllOccupantJIDs()[0].String())) + mock.ExpectQuery("SELECT (.+) FROM rooms_invites (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(invitesColumns). + AddRow(room.RoomJID.String(), room.GetAllInvitedUsers()[0])) + mock.ExpectCommit() + r, err := s.FetchRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.NotNil(t, r) + assert.EqualValues(t, room, r) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()).WillReturnError(errMocked) + mock.ExpectRollback() + _, err = s.FetchRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestMySQLStorageDeleteRoom(t *testing.T) { + room := getTestRoom() + s, mock := newRoomMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM rooms (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM rooms_config (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM rooms_users (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM rooms_invites (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.DeleteRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM rooms (.+)"). + WithArgs(room.RoomJID.String()).WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.DeleteRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestMySQLStorageRoomExists(t *testing.T) { + room := getTestRoom() + countCols := []string{"count"} + + s, mock := newRoomMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(countCols).AddRow(1)) + + ok, err := s.RoomExists(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.True(t, ok) + + s, mock = newRoomMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnError(errMocked) + _, err = s.RoomExists(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func newRoomMock() (*mySQLRoom, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLRoom{ + mySQLStorage: s, + }, sqlMock +} + +func getTestRoom() *mucmodel.Room { + rc := mucmodel.RoomConfig{ + Public: true, + Persistent: true, + PwdProtected: false, + Open: true, + Moderated: false, + } + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + + r := &mucmodel.Room{ + Name: "testRoom", + RoomJID: j, + Desc: "Room for Testing", + Config: &rc, + Locked: false, + } + + oJID, _ := jid.NewWithString("testroom@conference.jackal.im/owner", true) + owner, _ := mucmodel.NewOccupant(oJID, oJID.ToBareJID()) + r.AddOccupant(owner) + r.InviteUser(oJID.ToBareJID()) + + return r +} diff --git a/storage/mysql/roster.go b/storage/mysql/roster.go index 58de7e20f..009fded06 100644 --- a/storage/mysql/roster.go +++ b/storage/mysql/roster.go @@ -6,27 +6,36 @@ package mysql import ( + "context" "database/sql" "encoding/json" "strings" sq "github.com/Masterminds/squirrel" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -// InsertOrUpdateRosterItem inserts a new roster item entity into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) { +type mySQLRoster struct { + *mySQLStorage +} + +func newRoster(db *sql.DB) *mySQLRoster { + return &mySQLRoster{ + mySQLStorage: newStorage(db), + } +} + +func (s *mySQLRoster) UpsertRosterItem(ctx context.Context, ri *rostermodel.Item) (rostermodel.Version, error) { var ver rostermodel.Version - err := s.inTransaction(func(tx *sql.Tx) error { + err := s.inTransaction(ctx, func(tx *sql.Tx) error { q := sq.Insert("roster_versions"). Columns("username", "created_at", "updated_at"). Values(ri.Username, nowExpr, nowExpr). Suffix("ON DUPLICATE KEY UPDATE ver = ver + 1, updated_at = NOW()") - if _, err := q.RunWith(tx).Exec(); err != nil { + if _, err := q.RunWith(tx).ExecContext(ctx); err != nil { return err } groupsBytes, err := json.Marshal(ri.Groups) @@ -39,14 +48,14 @@ func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve Columns("username", "jid", "name", "subscription", "`groups`", "ask", "ver", "created_at", "updated_at"). Values(ri.Username, ri.JID, ri.Name, ri.Subscription, groupsBytes, ri.Ask, verExpr, nowExpr, nowExpr). Suffix("ON DUPLICATE KEY UPDATE name = ?, subscription = ?, `groups` = ?, ask = ?, ver = ver + 1, updated_at = NOW()", ri.Name, ri.Subscription, groupsBytes, ri.Ask) - _, err = q.RunWith(tx).Exec() + _, err = q.RunWith(tx).ExecContext(ctx) if err != nil { return err } // delete previous groups _, err = sq.Delete("roster_groups"). Where(sq.And{sq.Eq{"username": ri.Username}, sq.Eq{"jid": ri.JID}}). - RunWith(tx).Exec() + RunWith(tx).ExecContext(ctx) if err != nil { return err } @@ -55,13 +64,13 @@ func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve q = sq.Insert("roster_groups"). Columns("username", "jid", "`group`", "created_at", "updated_at"). Values(ri.Username, ri.JID, group, nowExpr, nowExpr) - _, err := q.RunWith(tx).Exec() + _, err := q.RunWith(tx).ExecContext(ctx) if err != nil { return err } } // fetch new roster version - ver, err = fetchRosterVer(ri.Username, tx) + ver, err = fetchRosterVer(ctx, ri.Username, tx) return err }) if err != nil { @@ -70,36 +79,35 @@ func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve return ver, nil } -// DeleteRosterItem deletes a roster item entity from storage. -func (s *Storage) DeleteRosterItem(username, jid string) (rostermodel.Version, error) { +func (s *mySQLRoster) DeleteRosterItem(ctx context.Context, username, jid string) (rostermodel.Version, error) { var ver rostermodel.Version - err := s.inTransaction(func(tx *sql.Tx) error { + err := s.inTransaction(ctx, func(tx *sql.Tx) error { q := sq.Insert("roster_versions"). Columns("username", "created_at", "updated_at"). Values(username, nowExpr, nowExpr). Suffix("ON DUPLICATE KEY UPDATE ver = ver + 1, last_deletion_ver = ver, updated_at = NOW()") - if _, err := q.RunWith(tx).Exec(); err != nil { + if _, err := q.RunWith(tx).ExecContext(ctx); err != nil { return err } // delete groups _, err := sq.Delete("roster_groups"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"jid": jid}}). - RunWith(tx).Exec() + RunWith(tx).ExecContext(ctx) if err != nil { return err } // delete items _, err = sq.Delete("roster_items"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"jid": jid}}). - RunWith(tx).Exec() + RunWith(tx).ExecContext(ctx) if err != nil { return err } // fetch new roster version - ver, err = fetchRosterVer(username, tx) + ver, err = fetchRosterVer(ctx, username, tx) return err }) if err != nil { @@ -108,65 +116,60 @@ func (s *Storage) DeleteRosterItem(username, jid string) (rostermodel.Version, e return ver, nil } -// FetchRosterItems retrieves from storage all roster item entities -// associated to a given user. -func (s *Storage) FetchRosterItems(username string) ([]rostermodel.Item, rostermodel.Version, error) { +func (s *mySQLRoster) FetchRosterItems(ctx context.Context, username string) ([]rostermodel.Item, rostermodel.Version, error) { q := sq.Select("username", "jid", "name", "subscription", "`groups`", "ask", "ver"). From("roster_items"). Where(sq.Eq{"username": username}). OrderBy("created_at DESC") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, rostermodel.Version{}, err } defer func() { _ = rows.Close() }() - items, err := s.scanRosterItemEntities(rows) + items, err := scanRosterItemEntities(rows) if err != nil { return nil, rostermodel.Version{}, err } - ver, err := fetchRosterVer(username, s.db) + ver, err := fetchRosterVer(ctx, username, s.db) if err != nil { return nil, rostermodel.Version{}, err } return items, ver, nil } -// FetchRosterItemsInGroups retrieves from storage all roster item entities -// associated to a given user and a set of groups. -func (s *Storage) FetchRosterItemsInGroups(username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { +func (s *mySQLRoster) FetchRosterItemsInGroups(ctx context.Context, username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { q := sq.Select("ris.username", "ris.jid", "ris.name", "ris.subscription", "ris.`groups`", "ris.ask", "ris.ver"). From("roster_items ris"). - LeftJoin("roster_groups g on ris.username = g.username"). + LeftJoin("roster_groups g ON ris.username = g.username"). Where(sq.And{sq.Eq{"ris.username": username}, sq.Eq{"g.group": groups}}). OrderBy("ris.created_at DESC") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, rostermodel.Version{}, err } defer func() { _ = rows.Close() }() - items, err := s.scanRosterItemEntities(rows) + items, err := scanRosterItemEntities(rows) if err != nil { return nil, rostermodel.Version{}, err } - ver, err := fetchRosterVer(username, s.db) + ver, err := fetchRosterVer(ctx, username, s.db) if err != nil { return nil, rostermodel.Version{}, err } return items, ver, nil } -// FetchRosterItem retrieves from storage a roster item entity. -func (s *Storage) FetchRosterItem(username, jid string) (*rostermodel.Item, error) { +func (s *mySQLRoster) FetchRosterItem(ctx context.Context, username, jid string) (*rostermodel.Item, error) { q := sq.Select("username", "jid", "name", "subscription", "`groups`", "ask", "ver"). From("roster_items"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"jid": jid}}) var ri rostermodel.Item - err := s.scanRosterItemEntity(&ri, q.RunWith(s.db).QueryRow()) + err := scanRosterItemEntity(&ri, q.RunWith(s.db).QueryRowContext(ctx)) switch err { case nil: return &ri, nil @@ -177,36 +180,32 @@ func (s *Storage) FetchRosterItem(username, jid string) (*rostermodel.Item, erro } } -// InsertOrUpdateRosterNotification inserts a new roster notification entity -// into storage, or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error { +func (s *mySQLRoster) UpsertRosterNotification(ctx context.Context, rn *rostermodel.Notification) error { presenceXML := rn.Presence.String() q := sq.Insert("roster_notifications"). Columns("contact", "jid", "elements", "updated_at", "created_at"). Values(rn.Contact, rn.JID, presenceXML, nowExpr, nowExpr). Suffix("ON DUPLICATE KEY UPDATE elements = ?, updated_at = NOW()", presenceXML) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -// FetchRosterNotifications retrieves from storage all roster notifications -// associated to a given user. -func (s *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { +func (s *mySQLRoster) FetchRosterNotifications(ctx context.Context, contact string) ([]rostermodel.Notification, error) { q := sq.Select("contact", "jid", "elements"). From("roster_notifications"). Where(sq.Eq{"contact": contact}). OrderBy("created_at") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() var ret []rostermodel.Notification for rows.Next() { var rn rostermodel.Notification - if err := s.scanRosterNotificationEntity(&rn, rows); err != nil { + if err := scanRosterNotificationEntity(&rn, rows); err != nil { return nil, err } ret = append(ret, rn) @@ -214,14 +213,13 @@ func (s *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notifi return ret, nil } -// FetchRosterNotification retrieves from storage a roster notification entity. -func (s *Storage) FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) { +func (s *mySQLRoster) FetchRosterNotification(ctx context.Context, contact string, jid string) (*rostermodel.Notification, error) { q := sq.Select("contact", "jid", "elements"). From("roster_notifications"). Where(sq.And{sq.Eq{"contact": contact}, sq.Eq{"jid": jid}}) var rn rostermodel.Notification - err := s.scanRosterNotificationEntity(&rn, q.RunWith(s.db).QueryRow()) + err := scanRosterNotificationEntity(&rn, q.RunWith(s.db).QueryRowContext(ctx)) switch err { case nil: return &rn, nil @@ -232,14 +230,36 @@ func (s *Storage) FetchRosterNotification(contact string, jid string) (*rostermo } } -// DeleteRosterNotification deletes a roster notification entity from storage. -func (s *Storage) DeleteRosterNotification(contact, jid string) error { +func (s *mySQLRoster) DeleteRosterNotification(ctx context.Context, contact, jid string) error { q := sq.Delete("roster_notifications").Where(sq.And{sq.Eq{"contact": contact}, sq.Eq{"jid": jid}}) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -func (s *Storage) scanRosterNotificationEntity(rn *rostermodel.Notification, scanner rowScanner) error { +func (s *mySQLRoster) FetchRosterGroups(ctx context.Context, username string) ([]string, error) { + q := sq.Select("`group`"). + From("roster_groups"). + Where(sq.Eq{"username": username}). + GroupBy("`group`") + + rows, err := q.RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var groups []string + for rows.Next() { + var group string + if err := rows.Scan(&group); err != nil { + return nil, err + } + groups = append(groups, group) + } + return groups, nil +} + +func scanRosterNotificationEntity(rn *rostermodel.Notification, scanner rowScanner) error { var presenceXML string if err := scanner.Scan(&rn.Contact, &rn.JID, &presenceXML); err != nil { return err @@ -255,7 +275,7 @@ func (s *Storage) scanRosterNotificationEntity(rn *rostermodel.Notification, sca return nil } -func (s *Storage) scanRosterItemEntity(ri *rostermodel.Item, scanner rowScanner) error { +func scanRosterItemEntity(ri *rostermodel.Item, scanner rowScanner) error { var groupsBytes string if err := scanner.Scan(&ri.Username, &ri.JID, &ri.Name, &ri.Subscription, &groupsBytes, &ri.Ask, &ri.Ver); err != nil { return err @@ -268,11 +288,11 @@ func (s *Storage) scanRosterItemEntity(ri *rostermodel.Item, scanner rowScanner) return nil } -func (s *Storage) scanRosterItemEntities(scanner rowsScanner) ([]rostermodel.Item, error) { +func scanRosterItemEntities(scanner rowsScanner) ([]rostermodel.Item, error) { var ret []rostermodel.Item for scanner.Next() { var ri rostermodel.Item - if err := s.scanRosterItemEntity(&ri, scanner); err != nil { + if err := scanRosterItemEntity(&ri, scanner); err != nil { return nil, err } ret = append(ret, ri) @@ -280,13 +300,13 @@ func (s *Storage) scanRosterItemEntities(scanner rowsScanner) ([]rostermodel.Ite return ret, nil } -func fetchRosterVer(username string, runner sq.BaseRunner) (rostermodel.Version, error) { +func fetchRosterVer(ctx context.Context, username string, runner sq.BaseRunner) (rostermodel.Version, error) { q := sq.Select("IFNULL(MAX(ver), 0)", "IFNULL(MAX(last_deletion_ver), 0)"). From("roster_versions"). Where(sq.Eq{"username": username}) var ver rostermodel.Version - row := q.RunWith(runner).QueryRow() + row := q.RunWith(runner).QueryRowContext(ctx) err := row.Scan(&ver.Ver, &ver.DeletionVer) switch err { case nil: diff --git a/storage/mysql/roster_test.go b/storage/mysql/roster_test.go index 8aa7ff914..d70890e7b 100644 --- a/storage/mysql/roster_test.go +++ b/storage/mysql/roster_test.go @@ -6,12 +6,13 @@ package mysql import ( + "context" "database/sql/driver" "encoding/json" "testing" "github.com/DATA-DOG/go-sqlmock" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" ) @@ -43,7 +44,7 @@ func TestMySQLStorageInsertRosterItem(t *testing.T) { ri.Ask, } - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectBegin() mock.ExpectExec("INSERT INTO roster_versions (.+) ON DUPLICATE KEY UPDATE (.+)"). @@ -72,13 +73,13 @@ func TestMySQLStorageInsertRosterItem(t *testing.T) { mock.ExpectCommit() - _, err := s.InsertOrUpdateRosterItem(&ri) + _, err := s.UpsertRosterItem(context.Background(), &ri) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) } func TestMySQLStorageDeleteRosterItem(t *testing.T) { - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectBegin() mock.ExpectExec("INSERT INTO roster_versions (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("user").WillReturnResult(sqlmock.NewResult(0, 1)) @@ -91,17 +92,17 @@ func TestMySQLStorageDeleteRosterItem(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"ver", "deletionVer"}).AddRow(1, 0)) mock.ExpectCommit() - _, err := s.DeleteRosterItem("user", "contact") + _, err := s.DeleteRosterItem(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectBegin() mock.ExpectExec("INSERT INTO roster_versions (.+)"). WithArgs("user").WillReturnError(errMySQLStorage) mock.ExpectRollback() - _, err = s.DeleteRosterItem("user", "contact") + _, err = s.DeleteRosterItem(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } @@ -109,7 +110,7 @@ func TestMySQLStorageDeleteRosterItem(t *testing.T) { func TestMySQLStorageFetchRosterItems(t *testing.T) { var riColumns = []string{"user", "contact", "name", "subscription", "`groups`", "ask", "ver"} - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(riColumns).AddRow("ortuman", "romeo", "Romeo", "both", "", false, 0)) @@ -117,66 +118,66 @@ func TestMySQLStorageFetchRosterItems(t *testing.T) { WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows([]string{"ver", "deletionVer"}).AddRow(0, 0)) - rosterItems, _, err := s.FetchRosterItems("ortuman") + rosterItems, _, err := s.FetchRosterItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 1, len(rosterItems)) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - _, _, err = s.FetchRosterItems("ortuman") + _, _, err = s.FetchRosterItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman", "romeo"). WillReturnRows(sqlmock.NewRows(riColumns).AddRow("ortuman", "romeo", "Romeo", "both", "", false, 0)) - _, err = s.FetchRosterItem("ortuman", "romeo") + _, err = s.FetchRosterItem(context.Background(), "ortuman", "romeo") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman", "romeo"). WillReturnRows(sqlmock.NewRows(riColumns)) - ri, _ := s.FetchRosterItem("ortuman", "romeo") + ri, _ := s.FetchRosterItem(context.Background(), "ortuman", "romeo") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, ri) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman", "romeo"). WillReturnError(errMySQLStorage) - _, err = s.FetchRosterItem("ortuman", "romeo") + _, err = s.FetchRosterItem(context.Background(), "ortuman", "romeo") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - _, _, err = s.FetchRosterItems("ortuman") + _, _, err = s.FetchRosterItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) var riColumns2 = []string{"ris.user", "ris.contact", "ris.name", "ris.subscription", "ris.`groups`", "ris.ask", "ris.ver"} - s, mock = NewMock() - mock.ExpectQuery("SELECT (.+) FROM roster_items ris LEFT JOIN roster_groups g on ris.username = g.username (.+)"). + s, mock = newRosterMock() + mock.ExpectQuery("SELECT (.+) FROM roster_items ris LEFT JOIN roster_groups g ON ris.username = g.username (.+)"). WithArgs("ortuman", "Family"). WillReturnRows(sqlmock.NewRows(riColumns2).AddRow("ortuman", "romeo", "Romeo", "both", `["Family"]`, false, 0)) mock.ExpectQuery("SELECT (.+) FROM roster_versions (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows([]string{"ver", "deletionVer"}).AddRow(0, 0)) - _, _, err = s.FetchRosterItemsInGroups("ortuman", []string{"Family"}) + _, _, err = s.FetchRosterItemsInGroups(context.Background(), "ortuman", []string{"Family"}) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) } @@ -195,39 +196,39 @@ func TestMySQLStorageInsertRosterNotification(t *testing.T) { presenceXML, presenceXML, } - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectExec("INSERT INTO roster_notifications (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs(args...). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdateRosterNotification(&rn) + err := s.UpsertRosterNotification(context.Background(), &rn) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectExec("INSERT INTO roster_notifications (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs(args...). WillReturnError(errMySQLStorage) - err = s.InsertOrUpdateRosterNotification(&rn) + err = s.UpsertRosterNotification(context.Background(), &rn) require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } func TestMySQLStorageDeleteRosterNotification(t *testing.T) { - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectExec("DELETE FROM roster_notifications (.+)"). WithArgs("user", "contact").WillReturnResult(sqlmock.NewResult(0, 1)) - err := s.DeleteRosterNotification("user", "contact") + err := s.DeleteRosterNotification(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectExec("DELETE FROM roster_notifications (.+)"). WithArgs("user", "contact").WillReturnError(errMySQLStorage) - err = s.DeleteRosterNotification("user", "contact") + err = s.DeleteRosterNotification(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } @@ -235,41 +236,77 @@ func TestMySQLStorageDeleteRosterNotification(t *testing.T) { func TestMySQLStorageFetchRosterNotifications(t *testing.T) { var rnColumns = []string{"user", "contact", "elements"} - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(rnColumns).AddRow("romeo", "contact", "8")) - rosterNotifications, err := s.FetchRosterNotifications("ortuman") + rosterNotifications, err := s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 1, len(rosterNotifications)) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(rnColumns)) - rosterNotifications, err = s.FetchRosterNotifications("ortuman") + rosterNotifications, err = s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 0, len(rosterNotifications)) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - _, err = s.FetchRosterNotifications("ortuman") + _, err = s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(rnColumns).AddRow("romeo", "contact", "8")) - _, err = s.FetchRosterNotifications("ortuman") + _, err = s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) } + +func TestMySQLStorageFetchRosterGroups(t *testing.T) { + s, mock := newRosterMock() + mock.ExpectQuery("SELECT `group` FROM roster_groups WHERE username = (.+) GROUP BY (.+)"). + WithArgs("ortuman"). + WillReturnRows(sqlmock.NewRows([]string{"group"}). + AddRow("Contacts"). + AddRow("News")) + + groups, err := s.FetchRosterGroups(context.Background(), "ortuman") + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.Equal(t, 2, len(groups)) + require.Equal(t, "Contacts", groups[0]) + require.Equal(t, "News", groups[1]) + + s, mock = newRosterMock() + mock.ExpectQuery("SELECT `group` FROM roster_groups WHERE username = (.+) GROUP BY (.+)"). + WithArgs("ortuman"). + WillReturnError(errMySQLStorage) + + groups, err = s.FetchRosterGroups(context.Background(), "ortuman") + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, groups) + require.NotNil(t, err) + require.Equal(t, errMySQLStorage, err) +} + +func newRosterMock() (*mySQLRoster, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLRoster{ + mySQLStorage: s, + }, sqlMock +} diff --git a/storage/mysql/sql.go b/storage/mysql/sql.go deleted file mode 100644 index 4b9ab9889..000000000 --- a/storage/mysql/sql.go +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package mysql - -import ( - "database/sql" - "fmt" - "time" - - sq "github.com/Masterminds/squirrel" - _ "github.com/go-sql-driver/mysql" // SQL driver - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/pool" -) - -var ( - nowExpr = sq.Expr("NOW()") -) - -type rowScanner interface { - Scan(...interface{}) error -} - -type rowsScanner interface { - rowScanner - Next() bool -} - -// Storage represents a SQL storage sub system. -type Storage struct { - db *sql.DB - pool *pool.BufferPool - doneCh chan chan bool -} - -// New instantiates a SQL storage instance. -func New(cfg *Config) *Storage { - var err error - s := &Storage{ - pool: pool.NewBufferPool(), - doneCh: make(chan chan bool), - } - host := cfg.Host - user := cfg.User - pass := cfg.Password - db := cfg.Database - poolSize := cfg.PoolSize - - dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", user, pass, host, db) - s.db, err = sql.Open("mysql", dsn) - if err != nil { - log.Fatalf("%v", err) - } - s.db.SetMaxOpenConns(poolSize) // set max opened connection count - - if err := s.db.Ping(); err != nil { - log.Fatalf("%v", err) - } - go s.loop() - - return s -} - -// IsClusterCompatible returns whether or not the underlying storage subsystem can be used in cluster mode. -func (s *Storage) IsClusterCompatible() bool { return true } - -// Close shuts down SQL storage sub system. -func (s *Storage) Close() error { - ch := make(chan bool) - s.doneCh <- ch - <-ch - return nil -} - -func (s *Storage) loop() { - tc := time.NewTicker(time.Second * 15) - defer tc.Stop() - for { - select { - case <-tc.C: - err := s.db.Ping() - if err != nil { - log.Error(err) - } - case ch := <-s.doneCh: - s.db.Close() - close(ch) - return - } - } -} - -func (s *Storage) inTransaction(f func(tx *sql.Tx) error) error { - tx, txErr := s.db.Begin() - if txErr != nil { - return txErr - } - if err := f(tx); err != nil { - tx.Rollback() - return err - } - return tx.Commit() -} diff --git a/storage/mysql/sql_test.go b/storage/mysql/sql_test.go deleted file mode 100644 index 01e55ec21..000000000 --- a/storage/mysql/sql_test.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package mysql - -import ( - "errors" - - sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/pool" -) - -var ( - errMySQLStorage = errors.New("mysql: storage error") -) - -// NewMock returns a mocked SQL storage instance. -func NewMock() (*Storage, sqlmock.Sqlmock) { - var err error - var sqlMock sqlmock.Sqlmock - s := &Storage{ - pool: pool.NewBufferPool(), - } - s.db, sqlMock, err = sqlmock.New() - if err != nil { - log.Fatalf("%v", err) - } - return s, sqlMock -} diff --git a/storage/mysql/storage.go b/storage/mysql/storage.go new file mode 100644 index 000000000..0a3d003d1 --- /dev/null +++ b/storage/mysql/storage.go @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + "context" + "database/sql" + "errors" + + sq "github.com/Masterminds/squirrel" +) + +var ( + nowExpr = sq.Expr("NOW()") +) + +type rowScanner interface { + Scan(...interface{}) error +} + +type rowsScanner interface { + rowScanner + Next() bool +} + +// mySQLStorage represents a SQL storage sub system. +type mySQLStorage struct { + // DB represents a MySQL database handler. + db *sql.DB +} + +var ( + errMocked = errors.New("mysql: storage error") +) + +func newStorage(db *sql.DB) *mySQLStorage { + return &mySQLStorage{db: db} +} + +func (s *mySQLStorage) inTransaction(ctx context.Context, f func(tx *sql.Tx) error) error { + tx, txErr := s.db.BeginTx(ctx, nil) + if txErr != nil { + return txErr + } + if err := f(tx); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} diff --git a/storage/mysql/storage_test.go b/storage/mysql/storage_test.go new file mode 100644 index 000000000..5fbec33c6 --- /dev/null +++ b/storage/mysql/storage_test.go @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package mysql + +import ( + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/log" +) + +// newMock returns a mocked MySQL storage instance. +func newStorageMock() (*mySQLStorage, sqlmock.Sqlmock) { + db, sqlMock, err := sqlmock.New() + if err != nil { + log.Fatalf("%v", err) + } + return &mySQLStorage{db: db}, sqlMock +} diff --git a/storage/mysql/user.go b/storage/mysql/user.go index 0b32b55a5..6184cdb95 100644 --- a/storage/mysql/user.go +++ b/storage/mysql/user.go @@ -1,33 +1,47 @@ /* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. * See the LICENSE file for more information. */ package mysql import ( + "context" "database/sql" "strings" "time" sq "github.com/Masterminds/squirrel" "github.com/ortuman/jackal/model" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -// InsertOrUpdateUser inserts a new user entity into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateUser(u *model.User) error { +type mySQLUser struct { + *mySQLStorage + pool *pool.BufferPool +} + +func newUser(db *sql.DB) *mySQLUser { + return &mySQLUser{ + mySQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +func (u *mySQLUser) UpsertUser(ctx context.Context, usr *model.User) error { var presenceXML string - if u.LastPresence != nil { - buf := s.pool.Get() - u.LastPresence.ToXML(buf, true) + if usr.LastPresence != nil { + buf := u.pool.Get() + if err := usr.LastPresence.ToXML(buf, true); err != nil { + return err + } presenceXML = buf.String() - s.pool.Put(buf) + u.pool.Put(buf) } columns := []string{"username", "password", "updated_at", "created_at"} - values := []interface{}{u.Username, u.Password, nowExpr, nowExpr} + values := []interface{}{usr.Username, usr.Password, nowExpr, nowExpr} if len(presenceXML) > 0 { columns = append(columns, []string{"last_presence", "last_presence_at"}...) @@ -37,21 +51,21 @@ func (s *Storage) InsertOrUpdateUser(u *model.User) error { var suffixArgs []interface{} if len(presenceXML) > 0 { suffix = "ON DUPLICATE KEY UPDATE password = ?, last_presence = ?, last_presence_at = NOW(), updated_at = NOW()" - suffixArgs = []interface{}{u.Password, presenceXML} + suffixArgs = []interface{}{usr.Password, presenceXML} } else { suffix = "ON DUPLICATE KEY UPDATE password = ?, updated_at = NOW()" - suffixArgs = []interface{}{u.Password} + suffixArgs = []interface{}{usr.Password} } q := sq.Insert("users"). Columns(columns...). Values(values...). Suffix(suffix, suffixArgs...) - _, err := q.RunWith(s.db).Exec() + + _, err := q.RunWith(u.db).ExecContext(ctx) return err } -// FetchUser retrieves from storage a user entity. -func (s *Storage) FetchUser(username string) (*model.User, error) { +func (u *mySQLUser) FetchUser(ctx context.Context, username string) (*model.User, error) { q := sq.Select("username", "password", "last_presence", "last_presence_at"). From("users"). Where(sq.Eq{"username": username}) @@ -60,7 +74,9 @@ func (s *Storage) FetchUser(username string) (*model.User, error) { var presenceAt time.Time var usr model.User - err := q.RunWith(s.db).QueryRow().Scan(&usr.Username, &usr.Password, &presenceXML, &presenceAt) + err := q.RunWith(u.db). + QueryRowContext(ctx). + Scan(&usr.Username, &usr.Password, &presenceXML, &presenceAt) switch err { case nil: if len(presenceXML) > 0 { @@ -82,31 +98,30 @@ func (s *Storage) FetchUser(username string) (*model.User, error) { } } -// DeleteUser deletes a user entity from storage. -func (s *Storage) DeleteUser(username string) error { - return s.inTransaction(func(tx *sql.Tx) error { +func (u *mySQLUser) DeleteUser(ctx context.Context, username string) error { + return u.inTransaction(ctx, func(tx *sql.Tx) error { var err error - _, err = sq.Delete("offline_messages").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("offline_messages").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("roster_items").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("roster_items").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("roster_versions").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("roster_versions").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("private_storage").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("private_storage").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("vcards").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("vcards").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("users").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("users").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } @@ -114,11 +129,13 @@ func (s *Storage) DeleteUser(username string) error { }) } -// UserExists returns whether or not a user exists within storage. -func (s *Storage) UserExists(username string) (bool, error) { - q := sq.Select("COUNT(*)").From("users").Where(sq.Eq{"username": username}) +func (u *mySQLUser) UserExists(ctx context.Context, username string) (bool, error) { + q := sq.Select("COUNT(*)"). + From("users"). + Where(sq.Eq{"username": username}) + var count int - err := q.RunWith(s.db).QueryRow().Scan(&count) + err := q.RunWith(u.db).QueryRowContext(ctx).Scan(&count) switch err { case nil: return count > 0, nil diff --git a/storage/mysql/user_test.go b/storage/mysql/user_test.go index 4f53b48b9..213bc151b 100644 --- a/storage/mysql/user_test.go +++ b/storage/mysql/user_test.go @@ -6,11 +6,13 @@ package mysql import ( + "context" "testing" "time" sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/ortuman/jackal/model" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/stretchr/testify/require" @@ -23,26 +25,27 @@ func TestMySQLStorageInsertUser(t *testing.T) { user := model.User{Username: "ortuman", Password: "1234", LastPresence: p} - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectExec("INSERT INTO users (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("ortuman", "1234", p.String(), "1234", p.String()). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdateUser(&user) + err := s.UpsertUser(context.Background(), &user) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectExec("INSERT INTO users (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("ortuman", "1234", p.String(), "1234", p.String()). - WillReturnError(errMySQLStorage) - err = s.InsertOrUpdateUser(&user) + WillReturnError(errMocked) + + err = s.UpsertUser(context.Background(), &user) require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errMySQLStorage, err) + require.Equal(t, errMocked, err) } func TestMySQLStorageDeleteUser(t *testing.T) { - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectBegin() mock.ExpectExec("DELETE FROM offline_messages (.+)"). WithArgs("ortuman").WillReturnResult(sqlmock.NewResult(0, 1)) @@ -58,19 +61,19 @@ func TestMySQLStorageDeleteUser(t *testing.T) { WithArgs("ortuman").WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() - err := s.DeleteUser("ortuman") + err := s.DeleteUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectBegin() mock.ExpectExec("DELETE FROM offline_messages (.+)"). - WithArgs("ortuman").WillReturnError(errMySQLStorage) + WithArgs("ortuman").WillReturnError(errMocked) mock.ExpectRollback() - err = s.DeleteUser("ortuman") + err = s.DeleteUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errMySQLStorage, err) + require.Equal(t, errMocked, err) } func TestMySQLStorageFetchUser(t *testing.T) { @@ -80,49 +83,57 @@ func TestMySQLStorageFetchUser(t *testing.T) { var userColumns = []string{"username", "password", "last_presence", "last_presence_at"} - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectQuery("SELECT (.+) FROM users (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(userColumns)) - usr, _ := s.FetchUser("ortuman") + usr, _ := s.FetchUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, usr) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectQuery("SELECT (.+) FROM users (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(userColumns).AddRow("ortuman", "1234", p.String(), time.Now())) - _, err := s.FetchUser("ortuman") + _, err := s.FetchUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectQuery("SELECT (.+) FROM users (.+)"). - WithArgs("ortuman").WillReturnError(errMySQLStorage) - _, err = s.FetchUser("ortuman") + WithArgs("ortuman").WillReturnError(errMocked) + _, err = s.FetchUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errMySQLStorage, err) + require.Equal(t, errMocked, err) } func TestMySQLStorageUserExists(t *testing.T) { - countColums := []string{"count"} + countCols := []string{"count"} - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectQuery("SELECT COUNT(.+) FROM users (.+)"). WithArgs("ortuman"). - WillReturnRows(sqlmock.NewRows(countColums).AddRow(1)) + WillReturnRows(sqlmock.NewRows(countCols).AddRow(1)) - ok, err := s.UserExists("ortuman") + ok, err := s.UserExists(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.True(t, ok) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectQuery("SELECT COUNT(.+) FROM users (.+)"). WithArgs("romeo"). - WillReturnError(errMySQLStorage) - _, err = s.UserExists("romeo") + WillReturnError(errMocked) + _, err = s.UserExists(context.Background(), "romeo") require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errMySQLStorage, err) + require.Equal(t, errMocked, err) +} + +func newUserMock() (*mySQLUser, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLUser{ + mySQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock } diff --git a/storage/mysql/vcard.go b/storage/mysql/vcard.go index 6c4479bf1..bf5b8836a 100644 --- a/storage/mysql/vcard.go +++ b/storage/mysql/vcard.go @@ -6,6 +6,7 @@ package mysql import ( + "context" "database/sql" "strings" @@ -13,26 +14,35 @@ import ( "github.com/ortuman/jackal/xmpp" ) -// InsertOrUpdateVCard inserts a new vCard element into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateVCard(vCard xmpp.XElement, username string) error { +type mySQLVCard struct { + *mySQLStorage +} + +func newVCard(db *sql.DB) *mySQLVCard { + return &mySQLVCard{ + mySQLStorage: newStorage(db), + } +} + +// UpsertVCard inserts a new vCard element into storage, or updates it in case it's been previously inserted. +func (s *mySQLVCard) UpsertVCard(ctx context.Context, vCard xmpp.XElement, username string) error { rawXML := vCard.String() q := sq.Insert("vcards"). Columns("username", "vcard", "updated_at", "created_at"). Values(username, rawXML, nowExpr, nowExpr). Suffix("ON DUPLICATE KEY UPDATE vcard = ?, updated_at = NOW()", rawXML) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -// FetchVCard retrieves from storage a vCard element associated -// to a given user. -func (s *Storage) FetchVCard(username string) (xmpp.XElement, error) { +// FetchVCard retrieves from storage a vCard element associated to a given user. +func (s *mySQLVCard) FetchVCard(ctx context.Context, username string) (xmpp.XElement, error) { + var vCard string + q := sq.Select("vcard").From("vcards").Where(sq.Eq{"username": username}) - var vCard string - err := q.RunWith(s.db).QueryRow().Scan(&vCard) + err := q.RunWith(s.db).QueryRowContext(ctx).Scan(&vCard) switch err { case nil: parser := xmpp.NewParser(strings.NewReader(vCard), xmpp.DefaultMode, 0) diff --git a/storage/mysql/vcard_test.go b/storage/mysql/vcard_test.go index 321ca08c8..626e43922 100644 --- a/storage/mysql/vcard_test.go +++ b/storage/mysql/vcard_test.go @@ -6,6 +6,7 @@ package mysql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" @@ -17,22 +18,22 @@ func TestMySQLStorageInsertVCard(t *testing.T) { vCard := xmpp.NewElementName("vCard") rawXML := vCard.String() - s, mock := NewMock() + s, mock := newVCardMock() mock.ExpectExec("INSERT INTO vcards (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("ortuman", rawXML, rawXML). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdateVCard(vCard, "ortuman") + err := s.UpsertVCard(context.Background(), vCard, "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.NotNil(t, vCard) - s, mock = NewMock() + s, mock = newVCardMock() mock.ExpectExec("INSERT INTO vcards (.+) ON DUPLICATE KEY UPDATE (.+)"). WithArgs("ortuman", rawXML, rawXML). WillReturnError(errMySQLStorage) - err = s.InsertOrUpdateVCard(vCard, "ortuman") + err = s.UpsertVCard(context.Background(), vCard, "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errMySQLStorage, err) } @@ -40,32 +41,39 @@ func TestMySQLStorageInsertVCard(t *testing.T) { func TestMySQLStorageFetchVCard(t *testing.T) { var vCardColumns = []string{"vcard"} - s, mock := NewMock() + s, mock := newVCardMock() mock.ExpectQuery("SELECT (.+) FROM vcards (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(vCardColumns).AddRow("Miguel Ɓngel")) - vCard, err := s.FetchVCard("ortuman") + vCard, err := s.FetchVCard(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.NotNil(t, vCard) - s, mock = NewMock() + s, mock = newVCardMock() mock.ExpectQuery("SELECT (.+) FROM vcards (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(vCardColumns)) - vCard, err = s.FetchVCard("ortuman") + vCard, err = s.FetchVCard(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Nil(t, vCard) - s, mock = NewMock() + s, mock = newVCardMock() mock.ExpectQuery("SELECT (.+) FROM vcards (.+)"). WithArgs("ortuman"). WillReturnError(errMySQLStorage) - vCard, _ = s.FetchVCard("ortuman") + vCard, _ = s.FetchVCard(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, vCard) } + +func newVCardMock() (*mySQLVCard, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &mySQLVCard{ + mySQLStorage: s, + }, sqlMock +} diff --git a/storage/offline.go b/storage/offline.go deleted file mode 100644 index 04a87f1cc..000000000 --- a/storage/offline.go +++ /dev/null @@ -1,32 +0,0 @@ -package storage - -import "github.com/ortuman/jackal/xmpp" - -// offlineStorage defines storage operations for offline messages -type offlineStorage interface { - InsertOfflineMessage(message *xmpp.Message, username string) error - CountOfflineMessages(username string) (int, error) - FetchOfflineMessages(username string) ([]xmpp.Message, error) - DeleteOfflineMessages(username string) error -} - -// InsertOfflineMessage inserts a new message element into -// user's offline queue. -func InsertOfflineMessage(message *xmpp.Message, username string) error { - return instance().InsertOfflineMessage(message, username) -} - -// CountOfflineMessages returns current length of user's offline queue. -func CountOfflineMessages(username string) (int, error) { - return instance().CountOfflineMessages(username) -} - -// FetchOfflineMessages retrieves from storage current user offline queue. -func FetchOfflineMessages(username string) ([]xmpp.Message, error) { - return instance().FetchOfflineMessages(username) -} - -// DeleteOfflineMessages clears a user offline queue. -func DeleteOfflineMessages(username string) error { - return instance().DeleteOfflineMessages(username) -} diff --git a/storage/pgsql/block_list.go b/storage/pgsql/block_list.go index c49cde383..cafdcf2f5 100644 --- a/storage/pgsql/block_list.go +++ b/storage/pgsql/block_list.go @@ -6,73 +6,64 @@ package pgsql import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" "github.com/ortuman/jackal/model" ) -// InsertBlockListItems inserts a set of block list item entities -// into storage, only in case they haven't been previously inserted. -func (s *Storage) InsertBlockListItems(items []model.BlockListItem) error { - return s.inTransaction(func(tx *sql.Tx) error { - for _, item := range items { - q := sq.Insert("blocklist_items"). - Columns("username", "jid"). - Values(item.Username, item.JID). - RunWith(tx) +type pgSQLBlockList struct { + *pgSQLStorage +} - if _, err := q.Exec(); err != nil { - return err - } - } - return nil - }) +func newBlockList(db *sql.DB) *pgSQLBlockList { + return &pgSQLBlockList{ + pgSQLStorage: newStorage(db), + } } -// DeleteBlockListItems deletes a set of block list item entities from storage. -func (s *Storage) DeleteBlockListItems(items []model.BlockListItem) error { - return s.inTransaction(func(tx *sql.Tx) error { - for _, item := range items { - q := sq.Delete("blocklist_items"). - Where(sq.And{sq.Eq{"username": item.Username}, sq.Eq{"jid": item.JID}}). - RunWith(tx) +func (s *pgSQLBlockList) InsertBlockListItem(ctx context.Context, item *model.BlockListItem) error { + q := sq.Insert("blocklist_items"). + Columns("username", "jid"). + Values(item.Username, item.JID). + RunWith(s.db) + _, err := q.ExecContext(ctx) + return err +} - if _, err := q.Exec(); err != nil { - return err - } - } - return nil - }) +func (s *pgSQLBlockList) DeleteBlockListItem(ctx context.Context, item *model.BlockListItem) error { + q := sq.Delete("blocklist_items"). + Where(sq.And{sq.Eq{"username": item.Username}, sq.Eq{"jid": item.JID}}). + RunWith(s.db) + _, err := q.ExecContext(ctx) + return err } -// FetchBlockListItems retrieves from storage all block list item entities -// associated to a given user. -func (s *Storage) FetchBlockListItems(username string) ([]model.BlockListItem, error) { +func (s *pgSQLBlockList) FetchBlockListItems(ctx context.Context, username string) ([]model.BlockListItem, error) { q := sq.Select("username", "jid"). From("blocklist_items"). Where(sq.Eq{"username": username}). OrderBy("created_at") - rows, err := q.RunWith(s.db).Query() - + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, err } + defer func() { _ = rows.Close() }() - defer rows.Close() - - return s.scanBlockListItemEntities(rows) + return scanBlockListItemEntities(rows) } -func (s *Storage) scanBlockListItemEntities(scanner rowsScanner) ([]model.BlockListItem, error) { +func scanBlockListItemEntities(scanner rowsScanner) ([]model.BlockListItem, error) { var ret []model.BlockListItem for scanner.Next() { var it model.BlockListItem - scanner.Scan(&it.Username, &it.JID) + if err := scanner.Scan(&it.Username, &it.JID); err != nil { + return nil, err + } ret = append(ret, it) } - return ret, nil } diff --git a/storage/pgsql/block_list_test.go b/storage/pgsql/block_list_test.go index 812285ef3..89fc37d92 100644 --- a/storage/pgsql/block_list_test.go +++ b/storage/pgsql/block_list_test.go @@ -6,6 +6,7 @@ package pgsql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" @@ -13,45 +14,34 @@ import ( "github.com/stretchr/testify/require" ) -const ( - blockListInsert = "INSERT INTO blocklist_items (.+)" - blockListDelete = "DELETE FROM blocklist_items (.+)" - blockListSelect = "SELECT (.+) FROM blocklist_items (.+)" -) - // Insert a valid block list item func TestInsertValidBlockListItem(t *testing.T) { - s, mock := NewMock() - items := []model.BlockListItem{{Username: "ortuman", JID: "noelia@jackal.im"}} + s, mock := newBlockListMock() - mock.ExpectBegin() - mock.ExpectExec(blockListInsert).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.ExpectExec("INSERT INTO blocklist_items (.+)"). + WillReturnResult(sqlmock.NewResult(0, 1)) - err := s.InsertBlockListItems(items) + err := s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, err) require.Nil(t, mock.ExpectationsWereMet()) } // Insert the same row twice to test for key uniqueness validation func TestInsertDoubleBlockListItem(t *testing.T) { - s, mock := NewMock() - items := []model.BlockListItem{{Username: "ortuman", JID: "noelia@jackal.im"}} + s, mock := newBlockListMock() // First insertion will be successful - mock.ExpectBegin() - mock.ExpectExec(blockListInsert).WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() + mock.ExpectExec("INSERT INTO blocklist_items (.+)"). + WillReturnResult(sqlmock.NewResult(0, 1)) // Second insertion will fail - mock.ExpectBegin() - mock.ExpectExec(blockListInsert).WillReturnError(errGeneric) - mock.ExpectRollback() + mock.ExpectExec("INSERT INTO blocklist_items (.+)"). + WillReturnError(errGeneric) - err := s.InsertBlockListItems(items) + err := s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, err) - err = s.InsertBlockListItems(items) + err = s.InsertBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Equal(t, errGeneric, err) require.Nil(t, mock.ExpectationsWereMet()) } @@ -59,55 +49,59 @@ func TestInsertDoubleBlockListItem(t *testing.T) { // Test fetching block list items func TestFetchBlockListItems(t *testing.T) { var blockListColumns = []string{"username", "jid"} - s, mock := NewMock() + s, mock := newBlockListMock() - mock.ExpectQuery(blockListSelect).WithArgs("ortuman"). + mock.ExpectQuery("SELECT (.+) FROM blocklist_items (.+)"). + WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(blockListColumns).AddRow("ortuman", "noelia@jackal.im")) - _, err := s.FetchBlockListItems("ortuman") + _, err := s.FetchBlockListItems(context.Background(), "ortuman") require.Nil(t, err) require.Nil(t, mock.ExpectationsWereMet()) } // Test error handling on fetching block list items func TestFetchBlockListItemsError(t *testing.T) { - s, mock := NewMock() + s, mock := newBlockListMock() - mock.ExpectQuery(blockListSelect). + mock.ExpectQuery("SELECT (.+) FROM blocklist_items (.+)"). WithArgs("ortuman"). WillReturnError(errGeneric) - _, err := s.FetchBlockListItems("ortuman") + _, err := s.FetchBlockListItems(context.Background(), "ortuman") require.Equal(t, errGeneric, err) require.Nil(t, mock.ExpectationsWereMet()) } // Test deleting an item from the block list func TestDeleteBlockListItems(t *testing.T) { - s, mock := NewMock() - item := model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"} + s, mock := newBlockListMock() - mock.ExpectBegin() - mock.ExpectExec(blockListDelete). - WithArgs(item.Username, item.JID). + mock.ExpectExec("DELETE FROM blocklist_items (.+)"). + WithArgs("ortuman", "noelia@jackal.im"). WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - err := s.DeleteBlockListItems([]model.BlockListItem{item}) + err := s.DeleteBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Nil(t, err) require.Nil(t, mock.ExpectationsWereMet()) } // Test error handling on deleting a row from the block list func TestDeleteBlockListItemsError(t *testing.T) { - s, mock := NewMock() - items := []model.BlockListItem{{Username: "ortuman", JID: "noelia@jackal.im"}} + s, mock := newBlockListMock() - mock.ExpectBegin() - mock.ExpectExec(blockListDelete).WillReturnError(errGeneric) - mock.ExpectRollback() + mock.ExpectExec("DELETE FROM blocklist_items (.+)"). + WithArgs("ortuman", "noelia@jackal.im"). + WillReturnError(errGeneric) - err := s.DeleteBlockListItems(items) + err := s.DeleteBlockListItem(context.Background(), &model.BlockListItem{Username: "ortuman", JID: "noelia@jackal.im"}) require.Equal(t, errGeneric, err) require.Nil(t, mock.ExpectationsWereMet()) } + +func newBlockListMock() (*pgSQLBlockList, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLBlockList{ + pgSQLStorage: s, + }, sqlMock +} diff --git a/storage/pgsql/config.go b/storage/pgsql/config.go index f765f1d98..ec5823597 100644 --- a/storage/pgsql/config.go +++ b/storage/pgsql/config.go @@ -22,11 +22,9 @@ func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { PoolSize: defaultPoolSize, SSLMode: defaultSSLMode, } - if err := unmarshal(&parsed); err != nil { return err } - *c = Config(parsed) return nil diff --git a/storage/pgsql/occupant.go b/storage/pgsql/occupant.go new file mode 100644 index 000000000..324d72935 --- /dev/null +++ b/storage/pgsql/occupant.go @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "database/sql" + + sq "github.com/Masterminds/squirrel" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +type pgSQLOccupant struct { + *pgSQLStorage +} + +func newOccupant(db *sql.DB) *pgSQLOccupant { + return &pgSQLOccupant{ + pgSQLStorage: newStorage(db), + } +} + +func (o *pgSQLOccupant) UpsertOccupant(ctx context.Context, occ *mucmodel.Occupant) error { + return o.inTransaction(ctx, func(tx *sql.Tx) error { + // store occupants data (except for resources) + columns := []string{"occupant_jid", "bare_jid", "affiliation", "role"} + values := []interface{}{occ.OccupantJID.String(), occ.BareJID.String(), + occ.GetAffiliation(), occ.GetRole()} + q := sq.Insert("occupants"). + Columns(columns...). + Values(values...). + Suffix("ON CONFLICT (occupant_jid) DO UPDATE SET affiliation = $3, role = $4") + + _, err := q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + //store occupants resources + columns = []string{"occupant_jid", "resource"} + for _, res := range occ.GetAllResources() { + values = []interface{}{occ.OccupantJID.String(), res} + q = sq.Insert("resources"). + Columns(columns...). + Values(values...). + Suffix("ON CONFLICT (occupant_jid) DO UPDATE SET resource = $2") + + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + + return nil + }) +} + +func (o *pgSQLOccupant) DeleteOccupant(ctx context.Context, occJID *jid.JID) error { + return o.inTransaction(ctx, func(tx *sql.Tx) error { + _, err := sq.Delete("occupants").Where(sq.Eq{"occupant_jid": occJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("resources").Where(sq.Eq{"occupant_jid": occJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + return nil + }) +} + +func (o *pgSQLOccupant) FetchOccupant(ctx context.Context, occJID *jid.JID) (*mucmodel.Occupant, + error) { + tx, err := o.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + occ, err := fetchOccupantData(ctx, tx, occJID) + switch err { + case nil: + case sql.ErrNoRows: + _ = tx.Commit() + return nil, nil + default: + _ = tx.Rollback() + return nil, err + + } + + err = fetchOccupantResources(ctx, tx, occ, occJID) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return occ, nil +} + +func fetchOccupantData(ctx context.Context, tx *sql.Tx, occJID *jid.JID) (*mucmodel.Occupant, + error) { + var occ *mucmodel.Occupant + q := sq.Select("occupant_jid", "bare_jid", "affiliation", "role"). + From("occupants"). + Where(sq.Eq{"occupant_jid": occJID.String()}) + + var occJIDStr, bareJIDStr, affiliation, role string + err := q.RunWith(tx). + QueryRowContext(ctx). + Scan(&occJIDStr, &bareJIDStr, &affiliation, &role) + switch err { + case nil: + occJIDdb, err := jid.NewWithString(occJIDStr, false) + if err != nil { + return nil, err + } + bareJID, err := jid.NewWithString(bareJIDStr, false) + if err != nil { + return nil, err + } + occ, err = mucmodel.NewOccupant(occJIDdb, bareJID) + if err != nil { + return nil, err + } + err = occ.SetAffiliation(affiliation) + if err != nil { + return nil, err + } + err = occ.SetRole(role) + if err != nil { + return nil, err + } + default: + return nil, err + } + return occ, nil +} + +func fetchOccupantResources(ctx context.Context, tx *sql.Tx, occ *mucmodel.Occupant, + occJID *jid.JID) error { + resources, err := sq.Select("occupant_jid", "resource"). + From("resources"). + Where(sq.Eq{"occupant_jid": occJID.String()}). + RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + for resources.Next() { + var dummy, res string + if err = resources.Scan(&dummy, &res); err != nil { + return err + } + occ.AddResource(res) + } + return nil +} + +func (o *pgSQLOccupant) OccupantExists(ctx context.Context, occJID *jid.JID) (bool, error) { + q := sq.Select("COUNT(*)"). + From("occupants"). + Where(sq.Eq{"occupant_jid": occJID.String()}) + + var count int + err := q.RunWith(o.db).QueryRowContext(ctx).Scan(&count) + switch err { + case nil: + return count > 0, nil + default: + return false, err + } +} diff --git a/storage/pgsql/occupant_test.go b/storage/pgsql/occupant_test.go new file mode 100644 index 000000000..e0463d6f7 --- /dev/null +++ b/storage/pgsql/occupant_test.go @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestPgSQLStorageInsertOccupant(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + o, _ := mucmodel.NewOccupant(j, j.ToBareJID()) + o.AddResource("yard") + o.SetAffiliation("owner") + o.SetRole("moderator") + + s, mock := newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO occupants (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(o.OccupantJID.String(), o.BareJID.String(), o.GetAffiliation(), o.GetRole()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO resources (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(o.OccupantJID.String(), "yard"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := s.UpsertOccupant(context.Background(), o) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO occupants (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(o.OccupantJID.String(), o.BareJID.String(), o.GetAffiliation(), o.GetRole()). + WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.UpsertOccupant(context.Background(), o) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, err, errMocked) +} + +func TestPgSQLStorageDeleteOccupant(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + s, mock := newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM occupants (.+)"). + WithArgs(j.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM resources (.+)"). + WithArgs(j.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.DeleteOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM occupants (.+)"). + WithArgs(j.String()).WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.DeleteOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestPgSQLStorageFetchOccupant(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + + occColumns := []string{"occupant_jid", "bare_jid", "affiliation", "role"} + resColumns := []string{"occupant_jid", "resource"} + + s, mock := newOccupantMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(occColumns)) + mock.ExpectCommit() + + occ, _ := s.FetchOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, occ) + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(occColumns). + AddRow(j.String(), j.ToBareJID().String(), "owner", "moderator")) + mock.ExpectQuery("SELECT (.+) FROM resources (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(resColumns). + AddRow(j.String(), "phone")) + mock.ExpectCommit() + occ, err := s.FetchOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.NotNil(t, occ) + require.Equal(t, occ.OccupantJID.String(), j.String()) + require.Equal(t, occ.BareJID.String(), j.ToBareJID().String()) + require.Equal(t, occ.GetAffiliation(), "owner") + require.Equal(t, occ.GetRole(), "moderator") + require.Len(t, occ.GetAllResources(), 1) + require.Equal(t, occ.GetAllResources()[0], "phone") + + s, mock = newOccupantMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM occupants (.+)"). + WithArgs(j.String()).WillReturnError(errMocked) + mock.ExpectRollback() + _, err = s.FetchOccupant(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestPgSQLStorageOccupantExists(t *testing.T) { + j, _ := jid.NewWithString("room@conference.jackal.im/nick", true) + countCols := []string{"count"} + + s, mock := newOccupantMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnRows(sqlmock.NewRows(countCols).AddRow(1)) + + ok, err := s.OccupantExists(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.True(t, ok) + + s, mock = newOccupantMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM occupants (.+)"). + WithArgs(j.String()). + WillReturnError(errMocked) + _, err = s.OccupantExists(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func newOccupantMock() (*pgSQLOccupant, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLOccupant{ + pgSQLStorage: s, + }, sqlMock +} diff --git a/storage/pgsql/offline.go b/storage/pgsql/offline.go index 5663a62c9..1a3a6cdc1 100644 --- a/storage/pgsql/offline.go +++ b/storage/pgsql/offline.go @@ -6,47 +6,61 @@ package pgsql import ( + "context" + "database/sql" + sq "github.com/Masterminds/squirrel" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -// InsertOfflineMessage inserts a new message element into -// user's offline queue. -func (s *Storage) InsertOfflineMessage(message *xmpp.Message, username string) error { +type pgSQLOffline struct { + *pgSQLStorage + pool *pool.BufferPool +} + +func newOffline(db *sql.DB) *pgSQLOffline { + return &pgSQLOffline{ + pgSQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +// InsertOfflineMessage inserts a new message element into user's offline queue. +func (s *pgSQLOffline) InsertOfflineMessage(ctx context.Context, message *xmpp.Message, username string) error { q := sq.Insert("offline_messages"). Columns("username", "data"). Values(username, message.String()) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } // CountOfflineMessages returns current length of user's offline queue. -func (s *Storage) CountOfflineMessages(username string) (int, error) { +func (s *pgSQLOffline) CountOfflineMessages(ctx context.Context, username string) (int, error) { + var count int + q := sq.Select("COUNT(*)"). From("offline_messages"). Where(sq.Eq{"username": username}). OrderBy("created_at") - var count int - - if err := q.RunWith(s.db).Scan(&count); err != nil { + if err := q.RunWith(s.db).QueryRowContext(ctx).Scan(&count); err != nil { return 0, err } - return count, nil } // FetchOfflineMessages retrieves from storage current user offline queue. -func (s *Storage) FetchOfflineMessages(username string) ([]xmpp.Message, error) { +func (s *pgSQLOffline) FetchOfflineMessages(ctx context.Context, username string) ([]xmpp.Message, error) { q := sq.Select("data"). From("offline_messages"). Where(sq.Eq{"username": username}). OrderBy("created_at") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, err @@ -72,24 +86,24 @@ func (s *Storage) FetchOfflineMessages(username string) ([]xmpp.Message, error) return nil, err } - elems := rootEl.Elements().All() + elements := rootEl.Elements().All() - var msgs []xmpp.Message - for _, el := range elems { + messages := make([]xmpp.Message, len(elements)) + for i, el := range elements { fromJID, _ := jid.NewWithString(el.From(), true) toJID, _ := jid.NewWithString(el.To(), true) msg, err := xmpp.NewMessageFromElement(el, fromJID, toJID) if err != nil { return nil, err } - msgs = append(msgs, *msg) + messages[i] = *msg } - return msgs, nil + return messages, nil } // DeleteOfflineMessages clears a user offline queue. -func (s *Storage) DeleteOfflineMessages(username string) error { +func (s *pgSQLOffline) DeleteOfflineMessages(ctx context.Context, username string) error { q := sq.Delete("offline_messages").Where(sq.Eq{"username": username}) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } diff --git a/storage/pgsql/offline_test.go b/storage/pgsql/offline_test.go index d9abb0471..d900b65c5 100644 --- a/storage/pgsql/offline_test.go +++ b/storage/pgsql/offline_test.go @@ -6,9 +6,11 @@ package pgsql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/pborman/uuid" @@ -23,21 +25,21 @@ func TestInsertOfflineMessages(t *testing.T) { m, _ := xmpp.NewMessageFromElement(message, j, j) messageXML := m.String() - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectExec("INSERT INTO offline_messages (.+)"). WithArgs("ortuman", messageXML). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOfflineMessage(m, "ortuman") + err := s.InsertOfflineMessage(context.Background(), m, "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectExec("INSERT INTO offline_messages (.+)"). WithArgs("ortuman", messageXML). WillReturnError(errGeneric) - err = s.InsertOfflineMessage(m, "ortuman") + err = s.InsertOfflineMessage(context.Background(), m, "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) } @@ -45,30 +47,30 @@ func TestInsertOfflineMessages(t *testing.T) { func TestCountOfflineMessages(t *testing.T) { countColums := []string{"count"} - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectQuery("SELECT COUNT(.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(countColums).AddRow(1)) - cnt, _ := s.CountOfflineMessages("ortuman") + cnt, _ := s.CountOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 1, cnt) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT COUNT(.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(countColums)) - cnt, _ = s.CountOfflineMessages("ortuman") + cnt, _ = s.CountOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 0, cnt) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT COUNT(.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnError(errGeneric) - _, err := s.CountOfflineMessages("ortuman") + _, err := s.CountOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } @@ -76,57 +78,65 @@ func TestCountOfflineMessages(t *testing.T) { func TestFetchOfflineMessages(t *testing.T) { var offlineMessagesColumns = []string{"data"} - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(offlineMessagesColumns).AddRow("Hi!")) - msgs, _ := s.FetchOfflineMessages("ortuman") + msgs, _ := s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 1, len(msgs)) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(offlineMessagesColumns)) - msgs, _ = s.FetchOfflineMessages("ortuman") + msgs, _ = s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, 0, len(msgs)) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(offlineMessagesColumns).AddRow("Hi!")) - _, err := s.FetchOfflineMessages("ortuman") + _, err := s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectQuery("SELECT (.+) FROM offline_messages (.+)"). WithArgs("ortuman"). WillReturnError(errGeneric) - _, err = s.FetchOfflineMessages("ortuman") + _, err = s.FetchOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } func TestDeleteOfflineMessages(t *testing.T) { - s, mock := NewMock() + s, mock := newOfflineMock() mock.ExpectExec("DELETE FROM offline_messages (.+)"). WithArgs("ortuman").WillReturnResult(sqlmock.NewResult(0, 1)) - err := s.DeleteOfflineMessages("ortuman") + err := s.DeleteOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newOfflineMock() mock.ExpectExec("DELETE FROM offline_messages (.+)"). WithArgs("ortuman").WillReturnError(errGeneric) - err = s.DeleteOfflineMessages("ortuman") + err = s.DeleteOfflineMessages(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } + +func newOfflineMock() (*pgSQLOffline, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLOffline{ + pgSQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock +} diff --git a/storage/pgsql/pgsql.go b/storage/pgsql/pgsql.go new file mode 100644 index 000000000..8524c631d --- /dev/null +++ b/storage/pgsql/pgsql.go @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "database/sql" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + _ "github.com/lib/pq" // PostgreSQL driver + "github.com/ortuman/jackal/log" + "github.com/ortuman/jackal/storage/repository" +) + +// pingInterval defines how often to check the connection +var pingInterval = 15 * time.Second + +// pingTimeout defines how long to wait for pong from server +var pingTimeout = 10 * time.Second + +type pgSQLContainer struct { + user *pgSQLUser + roster *pgSQLRoster + presences *pgSQLPresences + vCard *pgSQLVCard + priv *pgSQLPrivate + blockList *pgSQLBlockList + pubSub *pgSQLPubSub + offline *pgSQLOffline + room *pgSQLRoom + occupant *pgSQLOccupant + + h *sql.DB + cancelPing context.CancelFunc + doneCh chan chan bool +} + +// New initializes PgSQL storage and returns associated container. +func New(cfg *Config) (repository.Container, error) { + c := &pgSQLContainer{doneCh: make(chan chan bool, 1)} + + var err error + + sq.StatementBuilder = sq.StatementBuilder.PlaceholderFormat(sq.Dollar) + + dsn := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s", cfg.User, cfg.Password, cfg.Host, cfg.Database, cfg.SSLMode) + + c.h, err = sql.Open("postgres", dsn) + if err != nil { + return nil, err + } + c.h.SetMaxOpenConns(cfg.PoolSize) // set max opened connection count + + if err := c.ping(context.Background()); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + c.cancelPing = cancel + go c.loop(ctx) + + c.user = newUser(c.h) + c.roster = newRoster(c.h) + c.presences = newPresences(c.h) + c.vCard = newVCard(c.h) + c.priv = newPrivate(c.h) + c.blockList = newBlockList(c.h) + c.pubSub = newPubSub(c.h) + c.offline = newOffline(c.h) + c.room = newRoom(c.h) + + return c, nil +} + +func (c *pgSQLContainer) User() repository.User { return c.user } +func (c *pgSQLContainer) Roster() repository.Roster { return c.roster } +func (c *pgSQLContainer) Presences() repository.Presences { return c.presences } +func (c *pgSQLContainer) VCard() repository.VCard { return c.vCard } +func (c *pgSQLContainer) Private() repository.Private { return c.priv } +func (c *pgSQLContainer) BlockList() repository.BlockList { return c.blockList } +func (c *pgSQLContainer) PubSub() repository.PubSub { return c.pubSub } +func (c *pgSQLContainer) Offline() repository.Offline { return c.offline } +func (c *pgSQLContainer) Room() repository.Room { return c.room } +func (c *pgSQLContainer) Occupant() repository.Occupant { return c.occupant } + +func (c *pgSQLContainer) Close(ctx context.Context) error { + ch := make(chan bool) + c.doneCh <- ch + select { + case <-ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *pgSQLContainer) IsClusterCompatible() bool { return true } + +func (c *pgSQLContainer) loop(ctx context.Context) { + tick := time.NewTicker(pingInterval) + defer tick.Stop() + + for { + select { + case <-tick.C: + if err := c.ping(ctx); err != nil { + log.Error(err) + } + + case ch := <-c.doneCh: + if err := c.h.Close(); err != nil { + log.Error(err) + } + close(ch) + return + + case <-ctx.Done(): + return + } + } +} + +// ping sends a ping request to the server and outputs any error to log +func (c *pgSQLContainer) ping(ctx context.Context) error { + pingCtx, cancel := context.WithDeadline(ctx, time.Now().Add(pingTimeout)) + defer cancel() + + return c.h.PingContext(pingCtx) +} diff --git a/storage/pgsql/pgsql_test.go b/storage/pgsql/pgsql_test.go new file mode 100644 index 000000000..c3b153a24 --- /dev/null +++ b/storage/pgsql/pgsql_test.go @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import "errors" + +var ( + errGeneric = errors.New("pgsql: generic storage error") +) diff --git a/storage/pgsql/presences.go b/storage/pgsql/presences.go new file mode 100644 index 000000000..6951b0dbe --- /dev/null +++ b/storage/pgsql/presences.go @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "database/sql" + "encoding/json" + "strings" + + sq "github.com/Masterminds/squirrel" + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/util/pool" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +type pgSQLPresences struct { + *pgSQLStorage + pool *pool.BufferPool +} + +func newPresences(db *sql.DB) *pgSQLPresences { + return &pgSQLPresences{ + pgSQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +func (s *pgSQLPresences) UpsertPresence(ctx context.Context, presence *xmpp.Presence, jid *jid.JID, allocationID string) (loaded bool, err error) { + buf := s.pool.Get() + defer s.pool.Put(buf) + if err := presence.ToXML(buf, true); err != nil { + return false, err + } + var node, ver string + if caps := presence.Capabilities(); caps != nil { + node = caps.Node + ver = caps.Ver + } + rawXML := buf.String() + + q := sq.Insert("presences"). + Columns("username", "domain", "resource", "presence", "node", "ver", "allocation_id"). + Values(jid.Node(), jid.Domain(), jid.Resource(), rawXML, node, ver, allocationID). + Suffix("ON CONFLICT (username, domain, resource) DO UPDATE SET presence = $4, node = $5, ver = $6, allocation_id = $7"). + Suffix("RETURNING CASE WHEN updated_at=created_at THEN true ELSE false END AS inserted") + + var inserted bool + err = q.RunWith(s.db).QueryRowContext(ctx).Scan(&inserted) + if err != nil { + return false, err + } + return inserted, nil +} + +func (s *pgSQLPresences) FetchPresence(ctx context.Context, jid *jid.JID) (*capsmodel.PresenceCaps, error) { + var rawXML, node, ver, featuresJSON string + + q := sq.Select("presence", "c.node", "c.ver", "c.features"). + From("presences AS p, capabilities AS c"). + Where(sq.And{ + sq.Eq{"username": jid.Node()}, + sq.Eq{"domain": jid.Domain()}, + sq.Eq{"resource": jid.Resource()}, + sq.Expr("p.node = c.node"), + sq.Expr("p.ver = c.ver"), + }). + RunWith(s.db) + + err := q.ScanContext(ctx, &rawXML, &node, &ver, &featuresJSON) + switch err { + case nil: + return scanPresenceAndCapabilties(rawXML, node, ver, featuresJSON) + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func (s *pgSQLPresences) FetchPresencesMatchingJID(ctx context.Context, jid *jid.JID) ([]capsmodel.PresenceCaps, error) { + var preds sq.And + if len(jid.Node()) > 0 { + preds = append(preds, sq.Eq{"username": jid.Node()}) + } + if len(jid.Domain()) > 0 { + preds = append(preds, sq.Eq{"domain": jid.Domain()}) + } + if len(jid.Resource()) > 0 { + preds = append(preds, sq.Eq{"resource": jid.Resource()}) + } + preds = append(preds, sq.Expr("p.node = c.node")) + preds = append(preds, sq.Expr("p.ver = c.ver")) + + q := sq.Select("presence", "c.node", "c.ver", "c.features"). + From("presences AS p, capabilities AS c"). + Where(preds). + RunWith(s.db) + + rows, err := q.QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var res []capsmodel.PresenceCaps + for rows.Next() { + var rawXML, node, ver, featuresJSON string + + if err := rows.Scan(&rawXML, &node, &ver, &featuresJSON); err != nil { + return nil, err + } + presenceCaps, err := scanPresenceAndCapabilties(rawXML, node, ver, featuresJSON) + if err != nil { + return nil, err + } + res = append(res, *presenceCaps) + } + return res, nil +} + +func (s *pgSQLPresences) DeletePresence(ctx context.Context, jid *jid.JID) error { + _, err := sq.Delete("presences"). + Where(sq.And{ + sq.Eq{"username": jid.Node()}, + sq.Eq{"domain": jid.Domain()}, + sq.Eq{"resource": jid.Resource()}, + }). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *pgSQLPresences) DeleteAllocationPresences(ctx context.Context, allocationID string) error { + _, err := sq.Delete("presences"). + Where(sq.Eq{"allocation_id": allocationID}). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *pgSQLPresences) ClearPresences(ctx context.Context) error { + _, err := sq.Delete("presences").RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *pgSQLPresences) UpsertCapabilities(ctx context.Context, caps *capsmodel.Capabilities) error { + b, err := json.Marshal(caps.Features) + if err != nil { + return err + } + _, err = sq.Insert("capabilities"). + Columns("node", "ver", "features"). + Values(caps.Node, caps.Ver, b). + Suffix("ON CONFLICT (node, ver) DO UPDATE SET features = $3"). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *pgSQLPresences) FetchCapabilities(ctx context.Context, node, ver string) (*capsmodel.Capabilities, error) { + var b string + err := sq.Select("features").From("capabilities"). + Where(sq.And{sq.Eq{"node": node}, sq.Eq{"ver": ver}}). + RunWith(s.db).QueryRowContext(ctx).Scan(&b) + switch err { + case nil: + var caps capsmodel.Capabilities + if err := json.NewDecoder(strings.NewReader(b)).Decode(&caps.Features); err != nil { + return nil, err + } + return &caps, nil + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func scanPresenceAndCapabilties(rawXML, node, ver, featuresJSON string) (*capsmodel.PresenceCaps, error) { + parser := xmpp.NewParser(strings.NewReader(rawXML), xmpp.DefaultMode, 0) + elem, err := parser.ParseElement() + if err != nil { + return nil, err + } + fromJID, _ := jid.NewWithString(elem.From(), true) + toJID, _ := jid.NewWithString(elem.To(), true) + + presence, err := xmpp.NewPresenceFromElement(elem, fromJID, toJID) + if err != nil { + return nil, err + } + var res capsmodel.PresenceCaps + + res.Presence = presence + if len(featuresJSON) > 0 { + res.Caps = &capsmodel.Capabilities{ + Node: node, + Ver: ver, + } + + if err := json.NewDecoder(strings.NewReader(featuresJSON)).Decode(&res.Caps.Features); err != nil { + return nil, err + } + } + return &res, nil +} diff --git a/storage/pgsql/presences_test.go b/storage/pgsql/presences_test.go new file mode 100644 index 000000000..64cd88f7b --- /dev/null +++ b/storage/pgsql/presences_test.go @@ -0,0 +1,183 @@ +package pgsql + +import ( + "context" + "encoding/json" + "testing" + + capsmodel "github.com/ortuman/jackal/model/capabilities" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/util/pool" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/require" +) + +func TestPgSQLPresences_UpsertPresence(t *testing.T) { + var columns = []string{"inserted"} + + s, mock := newPresencesMock() + mock.ExpectQuery("INSERT INTO presences (.+) VALUES (.+) ON CONFLICT (.+) DO UPDATE SET (.+) RETURNING CASE WHEN updated_at=created_at THEN true ELSE false END AS inserted"). + WithArgs("ortuman", "jackal.im", "yard", ``, "", "", "alloc-1234"). + WillReturnRows(sqlmock.NewRows(columns).AddRow(true)) + + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + inserted, err := s.UpsertPresence(context.Background(), xmpp.NewPresence(j, j.ToBareJID(), xmpp.AvailableType), j, "alloc-1234") + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.True(t, inserted) +} + +func TestPgSQLPresences_FetchPresence(t *testing.T) { + var columns = []string{"presence", "c.node", "c.ver", "c.features"} + + s, mock := newPresencesMock() + mock.ExpectQuery("SELECT presence, c.node, c.ver, c.features FROM presences AS p, capabilities AS c WHERE \\(username = \\? AND domain = \\? AND resource = \\? AND p.node = c.node AND p.ver = c.ver\\)"). + WithArgs("ortuman", "jackal.im", "yard"). + WillReturnRows(sqlmock.NewRows(columns). + AddRow("", "http://jackal.im", "v1234", `["urn:xmpp:ping"]`)) + + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + presenceCaps, err := s.FetchPresence(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.NotNil(t, presenceCaps) + + require.Equal(t, "http://jackal.im", presenceCaps.Caps.Node) + require.Equal(t, "v1234", presenceCaps.Caps.Ver) + require.Len(t, presenceCaps.Caps.Features, 1) + require.Equal(t, "urn:xmpp:ping", presenceCaps.Caps.Features[0]) +} + +func TestPgSQLPresences_FetchPresencesMatchingJID(t *testing.T) { + var columns = []string{"presence", "c.node", "c.ver", "c.features"} + + s, mock := newPresencesMock() + mock.ExpectQuery("SELECT presence, c.node, c.ver, c.features FROM presences AS p, capabilities AS c WHERE \\(username = \\? AND domain = \\? AND resource = \\? AND p.node = c.node AND p.ver = c.ver\\)"). + WithArgs("ortuman", "jackal.im", "yard"). + WillReturnRows(sqlmock.NewRows(columns). + AddRow("", "http://jackal.im", "v1234", `["urn:xmpp:ping"]`). + AddRow("", "http://jackal.im", "v1234", `["urn:xmpp:ping"]`), + ) + + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + presenceCaps, err := s.FetchPresencesMatchingJID(context.Background(), j) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.NotNil(t, presenceCaps) + + require.Equal(t, "http://jackal.im", presenceCaps[0].Caps.Node) + require.Equal(t, "v1234", presenceCaps[0].Caps.Ver) + require.Len(t, presenceCaps[0].Caps.Features, 1) + require.Equal(t, "urn:xmpp:ping", presenceCaps[0].Caps.Features[0]) +} + +func TestPgSQLPresences_DeletePresence(t *testing.T) { + j, _ := jid.NewWithString("ortuman@jackal.im/yard", true) + + s, mock := newPresencesMock() + mock.ExpectExec("DELETE FROM presences WHERE \\(username = \\? AND domain = \\? AND resource = \\?\\)"). + WithArgs(j.Node(), j.Domain(), j.Resource()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := s.DeletePresence(context.Background(), j) + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestPgSQLPresences_DeleteAllocationPresence(t *testing.T) { + s, mock := newPresencesMock() + mock.ExpectExec("DELETE FROM presences WHERE allocation_id = ?"). + WithArgs("alloc-1234"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := s.DeleteAllocationPresences(context.Background(), "alloc-1234") + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestPgSQLPresences_ClearPresences(t *testing.T) { + s, mock := newPresencesMock() + mock.ExpectExec("DELETE FROM presences"). + WillReturnResult(sqlmock.NewResult(1, 1)) + + err := s.ClearPresences(context.Background()) + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestPgSQLPresences_UpsertCapabilities(t *testing.T) { + features := []string{"jabber:iq:last"} + + b, _ := json.Marshal(&features) + + s, mock := newPresencesMock() + mock.ExpectExec("INSERT INTO capabilities (.+) VALUES (.+) ON CONFLICT (.+) DO UPDATE SET features = (.+)"). + WithArgs("n1", "1234A", b). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := s.UpsertCapabilities(context.Background(), &capsmodel.Capabilities{Node: "n1", Ver: "1234A", Features: features}) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPresencesMock() + mock.ExpectExec("INSERT INTO capabilities (.+) VALUES (.+) ON CONFLICT (.+) DO UPDATE SET features = (.+)"). + WithArgs("n1", "1234A", b). + WillReturnError(errGeneric) + + err = s.UpsertCapabilities(context.Background(), &capsmodel.Capabilities{Node: "n1", Ver: "1234A", Features: features}) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLPresences_FetchCapabilities(t *testing.T) { + s, mock := newPresencesMock() + rows := sqlmock.NewRows([]string{"features"}) + rows.AddRow(`["jabber:iq:last"]`) + + mock.ExpectQuery("SELECT features FROM capabilities WHERE \\(node = . AND ver = .\\)"). + WithArgs("n1", "1234A"). + WillReturnRows(rows) + + caps, err := s.FetchCapabilities(context.Background(), "n1", "1234A") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 1, len(caps.Features)) + require.Equal(t, "jabber:iq:last", caps.Features[0]) + + // error case + s, mock = newPresencesMock() + mock.ExpectQuery("SELECT features FROM capabilities WHERE \\(node = . AND ver = .\\)"). + WithArgs("n1", "1234A"). + WillReturnError(errGeneric) + + caps, err = s.FetchCapabilities(context.Background(), "n1", "1234A") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Nil(t, caps) +} + +func newPresencesMock() (*pgSQLPresences, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLPresences{ + pgSQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock +} diff --git a/storage/pgsql/private.go b/storage/pgsql/private.go index 4eed63451..db5f096b4 100644 --- a/storage/pgsql/private.go +++ b/storage/pgsql/private.go @@ -6,20 +6,36 @@ package pgsql import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" ) -// InsertOrUpdatePrivateXML inserts a new private element into storage, +type pgSQLPrivate struct { + *pgSQLStorage + pool *pool.BufferPool +} + +func newPrivate(db *sql.DB) *pgSQLPrivate { + return &pgSQLPrivate{ + pgSQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +// UpsertPrivateXML inserts a new private element into storage, // or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error { +func (s *pgSQLPrivate) UpsertPrivateXML(ctx context.Context, privateXML []xmpp.XElement, namespace string, username string) error { buf := s.pool.Get() defer s.pool.Put(buf) for _, elem := range privateXML { - elem.ToXML(buf, true) + if err := elem.ToXML(buf, true); err != nil { + return err + } } rawXML := buf.String() @@ -29,18 +45,18 @@ func (s *Storage) InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace Values(username, namespace, rawXML). Suffix("ON CONFLICT (username, namespace) DO UPDATE SET data = $4", rawXML) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } // FetchPrivateXML retrieves from storage a private element. -func (s *Storage) FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) { +func (s *pgSQLPrivate) FetchPrivateXML(ctx context.Context, namespace string, username string) ([]xmpp.XElement, error) { q := sq.Select("data"). From("private_storage"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"namespace": namespace}}) var privateXML string - err := q.RunWith(s.db).QueryRow().Scan(&privateXML) + err := q.RunWith(s.db).QueryRowContext(ctx).Scan(&privateXML) switch err { case nil: buf := s.pool.Get() diff --git a/storage/pgsql/private_test.go b/storage/pgsql/private_test.go index 262f340ec..e0b371036 100644 --- a/storage/pgsql/private_test.go +++ b/storage/pgsql/private_test.go @@ -6,9 +6,11 @@ package pgsql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" ) @@ -17,21 +19,21 @@ func TestInsertPrivateXML(t *testing.T) { private := xmpp.NewElementNamespace("exodus", "exodus:ns") rawXML := private.String() - s, mock := NewMock() + s, mock := newPrivateMock() mock.ExpectExec("INSERT INTO private_storage (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs("ortuman", "exodus:ns", rawXML, rawXML). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") + err := s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectExec("INSERT INTO private_storage (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs("ortuman", "exodus:ns", rawXML, rawXML). WillReturnError(errGeneric) - err = s.InsertOrUpdatePrivateXML([]xmpp.XElement{private}, "exodus:ns", "ortuman") + err = s.UpsertPrivateXML(context.Background(), []xmpp.XElement{private}, "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } @@ -39,53 +41,61 @@ func TestInsertPrivateXML(t *testing.T) { func TestFetchPrivateXML(t *testing.T) { var privateColumns = []string{"data"} - s, mock := NewMock() + s, mock := newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns).AddRow("")) - elems, err := s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err := s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 1, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns).AddRow("")) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) require.Equal(t, 0, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns).AddRow("")) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 0, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnRows(sqlmock.NewRows(privateColumns)) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 0, len(elems)) - s, mock = NewMock() + s, mock = newPrivateMock() mock.ExpectQuery("SELECT (.+) FROM private_storage (.+)"). WithArgs("ortuman", "exodus:ns"). WillReturnError(errGeneric) - elems, err = s.FetchPrivateXML("exodus:ns", "ortuman") + elems, err = s.FetchPrivateXML(context.Background(), "exodus:ns", "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) require.Equal(t, 0, len(elems)) } + +func newPrivateMock() (*pgSQLPrivate, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLPrivate{ + pgSQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock +} diff --git a/storage/pgsql/pubsub.go b/storage/pgsql/pubsub.go new file mode 100644 index 000000000..ed24cfa5e --- /dev/null +++ b/storage/pgsql/pubsub.go @@ -0,0 +1,510 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "database/sql" + "strings" + + sq "github.com/Masterminds/squirrel" + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + "github.com/ortuman/jackal/xmpp" +) + +type pgSQLPubSub struct { + *pgSQLStorage +} + +func newPubSub(db *sql.DB) *pgSQLPubSub { + return &pgSQLPubSub{ + pgSQLStorage: newStorage(db), + } +} + +func (s *pgSQLPubSub) FetchHosts(ctx context.Context) ([]string, error) { + rows, err := sq.Select("DISTINCT(host)"). + From("pubsub_nodes"). + RunWith(s.db). + QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var hosts []string + for rows.Next() { + var host string + if err := rows.Scan(&host); err != nil { + return nil, err + } + hosts = append(hosts, host) + } + return hosts, nil +} + +func (s *pgSQLPubSub) UpsertNode(ctx context.Context, node *pubsubmodel.Node) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // if not existing, insert new node + _, err := sq.Insert("pubsub_nodes"). + Columns("host", "name", "updated_at", "created_at"). + Suffix("ON CONFLICT (host, name) DO NOTHING"). + Values(node.Host, node.Name, nowExpr, nowExpr). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // fetch node identifier + var nodeIdentifier string + + err = sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": node.Host}, sq.Eq{"name": node.Name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + if err != nil { + return err + } + + // delete previous node options + _, err = sq.Delete("pubsub_node_options"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // insert new option set + optionSetMap, err := node.Options.Map() + if err != nil { + return err + } + for name, value := range optionSetMap { + _, err = sq.Insert("pubsub_node_options"). + Columns("node_id", "name", "value"). + Values(nodeIdentifier, name, value). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + return nil + }) +} + +func (s *pgSQLPubSub) FetchNode(ctx context.Context, host, name string) (*pubsubmodel.Node, error) { + opts, err := s.fetchPubSubNodeOptions(ctx, host, name) + if err != nil { + return nil, err + } + if opts == nil { + return nil, nil // not found + } + return &pubsubmodel.Node{ + Host: host, + Name: name, + Options: *opts, + }, nil +} + +func (s *pgSQLPubSub) FetchNodes(ctx context.Context, host string) ([]pubsubmodel.Node, error) { + rows, err := sq.Select("name"). + From("pubsub_nodes"). + Where(sq.Eq{"host": host}). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var nodes []pubsubmodel.Node + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + var node = pubsubmodel.Node{Host: host, Name: name} + opts, err := s.fetchPubSubNodeOptions(ctx, host, name) + if err != nil { + return nil, err + } + if opts != nil { + node.Options = *opts + } + nodes = append(nodes, node) + } + return nodes, nil +} + +func (s *pgSQLPubSub) FetchSubscribedNodes(ctx context.Context, jid string) ([]pubsubmodel.Node, error) { + rows, err := sq.Select("host", "name"). + From("pubsub_nodes"). + Where(sq.Expr("id IN (SELECT DISTINCT(node_id) FROM pubsub_subscriptions WHERE jid = $1 AND subscription = $2)", jid, pubsubmodel.Subscribed)). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var nodes []pubsubmodel.Node + for rows.Next() { + var host, name string + if err := rows.Scan(&host, &name); err != nil { + return nil, err + } + var node = pubsubmodel.Node{Host: host, Name: name} + opts, err := s.fetchPubSubNodeOptions(ctx, host, name) + if err != nil { + return nil, err + } + if opts != nil { + node.Options = *opts + } + nodes = append(nodes, node) + } + return nodes, nil +} + +func (s *pgSQLPubSub) DeleteNode(ctx context.Context, host, name string) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + // delete node + _, err = sq.Delete("pubsub_nodes"). + Where(sq.Eq{"id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete options + _, err = sq.Delete("pubsub_node_options"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete items + _, err = sq.Delete("pubsub_items"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete affiliations + _, err = sq.Delete("pubsub_affiliations"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + // delete subscriptions + _, err = sq.Delete("pubsub_subscriptions"). + Where(sq.Eq{"node_id": nodeIdentifier}). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *pgSQLPubSub) UpsertNodeItem(ctx context.Context, item *pubsubmodel.Item, host, name string, maxNodeItems int) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + + // upsert new item + rawPayload := item.Payload.String() + + _, err = sq.Insert("pubsub_items"). + Columns("node_id", "item_id", "payload", "publisher"). + Values(nodeIdentifier, item.ID, rawPayload, item.Publisher). + Suffix("ON CONFLICT (node_id, item_id) DO UPDATE SET payload = $5, publisher = $6", rawPayload, item.Publisher). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // check if maximum item count was reached and delete oldest one + _, err = sq.Delete("pubsub_items"). + Where("item_id IN (SELECT item_id FROM pubsub_items WHERE node_id = $1 ORDER BY created_at DESC OFFSET $2)", nodeIdentifier, maxNodeItems). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *pgSQLPubSub) FetchNodeItems(ctx context.Context, host, name string) ([]pubsubmodel.Item, error) { + rows, err := sq.Select("item_id", "publisher", "payload"). + From("pubsub_items"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", host, name). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeItems(rows) +} + +func (s *pgSQLPubSub) FetchNodeItemsWithIDs(ctx context.Context, host, name string, identifiers []string) ([]pubsubmodel.Item, error) { + rows, err := sq.Select("item_id", "publisher", "payload"). + From("pubsub_items"). + Where(sq.And{sq.Expr("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", host, name), sq.Eq{"id": identifiers}}). + OrderBy("created_at"). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeItems(rows) +} + +func (s *pgSQLPubSub) FetchNodeLastItem(ctx context.Context, host, name string) (*pubsubmodel.Item, error) { + row := sq.Select("item_id", "publisher", "payload"). + From("pubsub_items"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", host, name). + OrderBy("created_at DESC"). + Limit(1). + RunWith(s.db).QueryRowContext(ctx) + + item, err := scanPubSubNodeItem(row) + switch err { + case nil: + return item, nil + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func (s *pgSQLPubSub) UpsertNodeAffiliation(ctx context.Context, affiliation *pubsubmodel.Affiliation, host, name string) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + + // upsert affiliation + _, err = sq.Insert("pubsub_affiliations"). + Columns("node_id", "jid", "affiliation"). + Values(nodeIdentifier, affiliation.JID, affiliation.Affiliation). + Suffix("ON CONFLICT (node_id, jid) DO UPDATE SET affiliation = $4", affiliation.Affiliation). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *pgSQLPubSub) FetchNodeAffiliation(ctx context.Context, host, name, jid string) (*pubsubmodel.Affiliation, error) { + var aff pubsubmodel.Affiliation + + row := sq.Select("jid", "affiliation"). + From("pubsub_affiliations"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2) AND jid = $3", host, name, jid). + RunWith(s.db).QueryRowContext(ctx) + err := row.Scan(&aff.JID, &aff.Affiliation) + switch err { + case nil: + return &aff, nil + case sql.ErrNoRows: + return nil, nil + default: + return nil, err + } +} + +func (s *pgSQLPubSub) FetchNodeAffiliations(ctx context.Context, host, name string) ([]pubsubmodel.Affiliation, error) { + rows, err := sq.Select("jid", "affiliation"). + From("pubsub_affiliations"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", host, name). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeAffiliations(rows) +} + +func (s *pgSQLPubSub) DeleteNodeAffiliation(ctx context.Context, jid, host, name string) error { + _, err := sq.Delete("pubsub_affiliations"). + Where("jid = $1 AND node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", jid, host, name). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *pgSQLPubSub) UpsertNodeSubscription(ctx context.Context, subscription *pubsubmodel.Subscription, host, name string) error { + return s.inTransaction(ctx, func(tx *sql.Tx) error { + // fetch node identifier + var nodeIdentifier string + + err := sq.Select("id"). + From("pubsub_nodes"). + Where(sq.And{sq.Eq{"host": host}, sq.Eq{"name": name}}). + RunWith(tx).QueryRowContext(ctx).Scan(&nodeIdentifier) + switch err { + case nil: + break + case sql.ErrNoRows: + return nil + default: + return err + } + + // upsert subscription + _, err = sq.Insert("pubsub_subscriptions"). + Columns("node_id", "subid", "jid", "subscription", "updated_at", "created_at"). + Values(nodeIdentifier, subscription.SubID, subscription.JID, subscription.Subscription, nowExpr, nowExpr). + Suffix("ON CONFLICT (node_id, jid) DO UPDATE SET subid = $5, subscription = $6", subscription.SubID, subscription.Subscription). + RunWith(tx).ExecContext(ctx) + return err + }) +} + +func (s *pgSQLPubSub) FetchNodeSubscriptions(ctx context.Context, host, name string) ([]pubsubmodel.Subscription, error) { + rows, err := sq.Select("subid", "jid", "subscription"). + From("pubsub_subscriptions"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", host, name). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + return scanPubSubNodeSubscriptions(rows) +} + +func (s *pgSQLPubSub) DeleteNodeSubscription(ctx context.Context, jid, host, name string) error { + _, err := sq.Delete("pubsub_subscriptions"). + Where("jid = $1 AND node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", jid, host, name). + RunWith(s.db).ExecContext(ctx) + return err +} + +func (s *pgSQLPubSub) fetchPubSubNodeOptions(ctx context.Context, host, name string) (*pubsubmodel.Options, error) { + rows, err := sq.Select("name", "value"). + From("pubsub_node_options"). + Where("node_id = (SELECT id FROM pubsub_nodes WHERE host = $1 AND name = $2)", host, name). + OrderBy("created_at"). + RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var optMap = make(map[string]string) + for rows.Next() { + var opt, value string + if err := rows.Scan(&opt, &value); err != nil { + return nil, err + } + optMap[opt] = value + } + if len(optMap) == 0 { + return nil, nil // node does not exist + } + opts, err := pubsubmodel.NewOptionsFromMap(optMap) + if err != nil { + return nil, err + } + return opts, nil +} + +func scanPubSubNodeAffiliations(scanner rowsScanner) ([]pubsubmodel.Affiliation, error) { + var affiliations []pubsubmodel.Affiliation + + for scanner.Next() { + var affiliation pubsubmodel.Affiliation + if err := scanner.Scan(&affiliation.JID, &affiliation.Affiliation); err != nil { + return nil, err + } + affiliations = append(affiliations, affiliation) + } + return affiliations, nil +} + +func scanPubSubNodeSubscriptions(scanner rowsScanner) ([]pubsubmodel.Subscription, error) { + var subscriptions []pubsubmodel.Subscription + + for scanner.Next() { + var subscription pubsubmodel.Subscription + if err := scanner.Scan(&subscription.SubID, &subscription.JID, &subscription.Subscription); err != nil { + return nil, err + } + subscriptions = append(subscriptions, subscription) + } + return subscriptions, nil +} + +func scanPubSubNodeItems(scanner rowsScanner) ([]pubsubmodel.Item, error) { + var items []pubsubmodel.Item + var err error + + for scanner.Next() { + var payload string + var item pubsubmodel.Item + if err := scanner.Scan(&item.ID, &item.Publisher, &payload); err != nil { + return nil, err + } + parser := xmpp.NewParser(strings.NewReader(payload), xmpp.DefaultMode, 0) + item.Payload, err = parser.ParseElement() + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, nil +} + +func scanPubSubNodeItem(scanner rowScanner) (*pubsubmodel.Item, error) { + var payload string + var item pubsubmodel.Item + var err error + + if err = scanner.Scan(&item.ID, &item.Publisher, &payload); err != nil { + return nil, err + } + parser := xmpp.NewParser(strings.NewReader(payload), xmpp.DefaultMode, 0) + item.Payload, err = parser.ParseElement() + if err != nil { + return nil, err + } + return &item, nil +} diff --git a/storage/pgsql/pubsub_test.go b/storage/pgsql/pubsub_test.go new file mode 100644 index 000000000..5e38b3509 --- /dev/null +++ b/storage/pgsql/pubsub_test.go @@ -0,0 +1,474 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/google/uuid" + pubsubmodel "github.com/ortuman/jackal/model/pubsub" + "github.com/ortuman/jackal/xmpp" + "github.com/stretchr/testify/require" +) + +func TestPgSQLFetchPubSubHosts(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"host"}) + rows.AddRow("ortuman@jackal.im") + rows.AddRow("noelia@jackal.im") + + mock.ExpectQuery("SELECT DISTINCT\\(host\\) FROM pubsub_nodes"). + WillReturnRows(rows) + + hosts, err := s.FetchHosts(context.Background()) + require.Nil(t, err) + require.NotNil(t, hosts) + require.Equal(t, "ortuman@jackal.im", hosts[0]) + require.Equal(t, "noelia@jackal.im", hosts[1]) + + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT DISTINCT\\(host\\) FROM pubsub_nodes"). + WillReturnError(errGeneric) + + hosts, err = s.FetchHosts(context.Background()) + require.Nil(t, hosts) + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLUpsertPubSubNode(t *testing.T) { + s, mock := newPubSubMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO pubsub_nodes (.+) ON CONFLICT (.+) DO NOTHING"). + WithArgs("host", "name"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("host", "name"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("DELETE FROM pubsub_node_options WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + opts := pubsubmodel.Options{} + + optMap, _ := opts.Map() + for i := 0; i < len(optMap); i++ { + mock.ExpectExec("INSERT INTO pubsub_node_options (.+)"). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 1)) + } + mock.ExpectCommit() + + node := pubsubmodel.Node{Host: "host", Name: "name", Options: opts} + err := s.UpsertNode(context.Background(), &node) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errGeneric) + + _, err = s.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLFetchPubSubNode(t *testing.T) { + var cols = []string{"name", "value"} + + s, mock := newPubSubMock() + rows := sqlmock.NewRows(cols) + rows.AddRow("pubsub#access_model", "presence") + rows.AddRow("pubsub#publish_model", "publishers") + rows.AddRow("pubsub#send_last_published_item", "on_sub_and_presence") + + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + node, err := s.FetchNode(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.NotNil(t, node) + require.Equal(t, node.Options.AccessModel, pubsubmodel.Presence) + require.Equal(t, node.Options.SendLastPublishedItem, pubsubmodel.OnSubAndPresence) +} + +func TestPgSQLFetchPubSubNodes(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"name"}) + rows.AddRow("princely_musings_1") + rows.AddRow("princely_musings_2") + + mock.ExpectQuery("SELECT name FROM pubsub_nodes WHERE host = (.+)"). + WithArgs("ortuman@jackal.im"). + WillReturnRows(rows) + + var cols = []string{"name", "value"} + + rows = sqlmock.NewRows(cols) + rows.AddRow("pubsub#access_model", "presence") + rows.AddRow("pubsub#publish_model", "publishers") + rows.AddRow("pubsub#send_last_published_item", "on_sub_and_presence") + + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_1"). + WillReturnRows(rows) + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_2"). + WillReturnRows(rows) + + nodes, err := s.FetchNodes(context.Background(), "ortuman@jackal.im") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.NotNil(t, nodes) + require.Len(t, nodes, 2) + require.Equal(t, "princely_musings_1", nodes[0].Name) + require.Equal(t, "princely_musings_2", nodes[1].Name) +} + +func TestPgSQLFetchPubSubSubscribedNodes(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"host", "name"}) + rows.AddRow("ortuman@jackal.im", "princely_musings_1") + rows.AddRow("ortuman@jackal.im", "princely_musings_2") + + mock.ExpectQuery("SELECT host, name FROM pubsub_nodes WHERE id IN (.+)"). + WithArgs("ortuman@jackal.im", pubsubmodel.Subscribed). + WillReturnRows(rows) + + var cols = []string{"name", "value"} + + rows = sqlmock.NewRows(cols) + rows.AddRow("pubsub#access_model", "presence") + rows.AddRow("pubsub#publish_model", "publishers") + rows.AddRow("pubsub#send_last_published_item", "on_sub_and_presence") + + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_1"). + WillReturnRows(rows) + mock.ExpectQuery("SELECT name, value FROM pubsub_node_options WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings_2"). + WillReturnRows(rows) + + nodes, err := s.FetchSubscribedNodes(context.Background(), "ortuman@jackal.im") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.NotNil(t, nodes) + require.Len(t, nodes, 2) + require.Equal(t, "princely_musings_1", nodes[0].Name) + require.Equal(t, "princely_musings_2", nodes[1].Name) +} + +func TestPgSQLDeletePubSubNode(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("DELETE FROM pubsub_nodes WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_node_options WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_items WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_affiliations WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("1"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.DeleteNode(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) +} + +func TestPgSQLUpsertPubSubNodeItem(t *testing.T) { + payload := xmpp.NewIQType(uuid.New().String(), xmpp.GetType) + + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("INSERT INTO pubsub_items (.+) ON CONFLICT (.+) DO UPDATE SET payload = (.+), publisher = (.+)"). + WithArgs("1", "abc1234", payload.String(), "ortuman@jackal.im", payload.String(), "ortuman@jackal.im"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectExec("DELETE FROM pubsub_items WHERE item_id IN (.+)"). + WithArgs("1", int64(1)). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.UpsertNodeItem(context.Background(), &pubsubmodel.Item{ + ID: "abc1234", + Publisher: "ortuman@jackal.im", + Payload: payload, + }, "ortuman@jackal.im", "princely_musings", 1) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestPgSQLFetchPubSubNodeItems(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"item_id", "publisher", "payload"}) + rows.AddRow("1234", "ortuman@jackal.im", "") + rows.AddRow("5678", "noelia@jackal.im", "") + + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE node_id = (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + items, err := s.FetchNodeItems(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(items)) + require.Equal(t, "1234", items[0].ID) + require.Equal(t, "5678", items[1].ID) + + // error case + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE node_id = (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errGeneric) + + _, err = s.FetchNodeItems(context.Background(), "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLFetchPubSubNodeItemsWithID(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"item_id", "publisher", "payload"}) + rows.AddRow("1234", "ortuman@jackal.im", "") + rows.AddRow("5678", "noelia@jackal.im", "") + + identifiers := []string{"1234", "5678"} + + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE (.+ IN (.+)) ORDER BY created_at"). + WithArgs("ortuman@jackal.im", "princely_musings", "1234", "5678"). + WillReturnRows(rows) + + items, err := s.FetchNodeItemsWithIDs(context.Background(), "ortuman@jackal.im", "princely_musings", identifiers) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(items)) + require.Equal(t, "1234", items[0].ID) + require.Equal(t, "5678", items[1].ID) + + // error case + s, mock = newPubSubMock() + mock.ExpectQuery("SELECT item_id, publisher, payload FROM pubsub_items WHERE (.+ IN (.+)) ORDER BY created_at"). + WithArgs("ortuman@jackal.im", "princely_musings", "1234", "5678"). + WillReturnError(errGeneric) + + _, err = s.FetchNodeItemsWithIDs(context.Background(), "ortuman@jackal.im", "princely_musings", identifiers) + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLUpsertPubSubNodeAffiliation(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("INSERT INTO pubsub_affiliations (.+) VALUES (.+) ON CONFLICT (.+) DO UPDATE SET affiliation = (.+)"). + WithArgs("1", "ortuman@jackal.im", "owner", "owner"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.UpsertNodeAffiliation(context.Background(), &pubsubmodel.Affiliation{ + JID: "ortuman@jackal.im", + Affiliation: "owner", + }, "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestPgSQLFetchPubSubNodeAffiliations(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"jid", "affiliation"}) + rows.AddRow("ortuman@jackal.im", "owner") + rows.AddRow("noelia@jackal.im", "publisher") + + mock.ExpectQuery("SELECT jid, affiliation FROM pubsub_affiliations WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + affiliations, err := s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(affiliations)) + + // error case + mock.ExpectQuery("SELECT jid, affiliation FROM pubsub_affiliations WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errGeneric) + + affiliations, err = s.FetchNodeAffiliations(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLDeletePubSubNodeAffiliation(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectExec("DELETE FROM pubsub_affiliations WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := s.DeleteNodeAffiliation(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPubSubMock() + mock.ExpectExec("DELETE FROM pubsub_affiliations WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnError(errGeneric) + + err = s.DeleteNodeAffiliation(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLUpsertPubSubNodeSubscription(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectBegin() + mock.ExpectQuery("SELECT id FROM pubsub_nodes WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + + mock.ExpectExec("INSERT INTO pubsub_subscriptions (.+) VALUES (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs("1", "1234", "ortuman@jackal.im", "subscribed", "1234", "subscribed"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.UpsertNodeSubscription(context.Background(), &pubsubmodel.Subscription{ + SubID: "1234", + JID: "ortuman@jackal.im", + Subscription: "subscribed", + }, "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) +} + +func TestPgSQLFetchPubSubNodeSubscriptions(t *testing.T) { + s, mock := newPubSubMock() + rows := sqlmock.NewRows([]string{"subid", "jid", "subscription"}) + rows.AddRow("1234", "ortuman@jackal.im", "subscribed") + rows.AddRow("5678", "noelia@jackal.im", "unsubscribed") + + mock.ExpectQuery("SELECT subid, jid, subscription FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnRows(rows) + + subscriptions, err := s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + require.Equal(t, 2, len(subscriptions)) + + // error case + mock.ExpectQuery("SELECT subid, jid, subscription FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("ortuman@jackal.im", "princely_musings"). + WillReturnError(errGeneric) + + subscriptions, err = s.FetchNodeSubscriptions(context.Background(), "ortuman@jackal.im", "princely_musings") + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func TestPgSQLDeletePubSubNodeSubscription(t *testing.T) { + s, mock := newPubSubMock() + + mock.ExpectExec("DELETE FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := s.DeleteNodeSubscription(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, err) + + // error case + s, mock = newPubSubMock() + mock.ExpectExec("DELETE FROM pubsub_subscriptions WHERE (.+)"). + WithArgs("noeliac@jackal.im", "ortuman@jackal.im", "princely_musings"). + WillReturnError(errGeneric) + + err = s.DeleteNodeSubscription(context.Background(), "noeliac@jackal.im", "ortuman@jackal.im", "princely_musings") + + require.Nil(t, mock.ExpectationsWereMet()) + + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func newPubSubMock() (*pgSQLPubSub, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLPubSub{ + pgSQLStorage: s, + }, sqlMock +} diff --git a/storage/pgsql/room.go b/storage/pgsql/room.go new file mode 100644 index 000000000..209f9e1ff --- /dev/null +++ b/storage/pgsql/room.go @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "database/sql" + + sq "github.com/Masterminds/squirrel" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +type pgSQLRoom struct { + *pgSQLStorage +} + +func newRoom(db *sql.DB) *pgSQLRoom { + return &pgSQLRoom{ + pgSQLStorage: newStorage(db), + } +} + +func (r *pgSQLRoom) UpsertRoom(ctx context.Context, room *mucmodel.Room) error { + return r.inTransaction(ctx, func(tx *sql.Tx) error { + // rooms table + columns := []string{"room_jid", "name", "description", "subject", "language", "locked", + "occupants_online"} + values := []interface{}{room.RoomJID.String(), room.Name, room.Desc, room.Subject, + room.Language, room.Locked, room.GetOccupantsOnlineCount()} + q := sq.Insert("rooms"). + Columns(columns...). + Values(values...). + Suffix("ON CONFLICT (room_jid) DO UPDATE SET name = $2, description = $3, subject = $4" + ", language = $5, locked = $6, occupants_online = $7") + _, err := q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // rooms_config table + rc := room.Config + columns = []string{"room_jid", "public", "persistent", "pwd_protected", "password", "open", + "moderated", "allow_invites", "max_occupants", "allow_subj_change", "non_anonymous", + "can_send_pm", "can_get_member_list"} + values = []interface{}{room.RoomJID.String(), rc.Public, rc.Persistent, rc.PwdProtected, + rc.Password, rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, + rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList()} + q = sq.Insert("rooms_config"). + Columns(columns...). + Values(values...). + Suffix("ON CONFLICT (room_jid) DO UPDATE SET public = $2, persistent = $3, pwd_protected = $4, " + + "password = $5, open = $6, moderated = $7, allow_invites = $8, max_occupants = $9, " + + "allow_subj_change = $10, non_anonymous = $11, can_send_pm = $12, can_get_member_list = $13") + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + + // rooms_invites table + columns = []string{"room_jid", "user_jid"} + for _, u := range room.GetAllInvitedUsers() { + values = []interface{}{room.RoomJID.String(), u} + q = sq.Insert("rooms_invites"). + Columns(columns...). + Values(values...). + Suffix("ON CONFLICT (room_jid) DO UPDATE SET user_jid = $2") + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + + // rooms_users table + columns = []string{"room_jid", "user_jid", "occupant_jid"} + for _, u := range room.GetAllUserJIDs() { + occJID, _ := room.GetOccupantJID(&u) + values = []interface{}{room.RoomJID.String(), u.String(), occJID.String()} + q = sq.Insert("rooms_users"). + Columns(columns...). + Values(values...). + Suffix("ON CONFLICT (room_jid) DO UPDATE SET occupant_jid = $3") + _, err = q.RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + } + return nil + }) +} + +func (r *pgSQLRoom) FetchRoom(ctx context.Context, roomJID *jid.JID) (*mucmodel.Room, error) { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + room, err := fetchRoomData(ctx, tx, roomJID) + switch err { + case nil: + case sql.ErrNoRows: + _ = tx.Commit() + return nil, nil + default: + _ = tx.Rollback() + return nil, err + } + + err = fetchRoomConfig(ctx, tx, room, roomJID) + switch err { + case nil: + case sql.ErrNoRows: + _ = tx.Commit() + return nil, nil + default: + _ = tx.Rollback() + return nil, err + } + + err = fetchRoomUsers(ctx, tx, room, roomJID) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + err = fetchRoomInvites(ctx, tx, room, roomJID) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + return room, nil +} + +func fetchRoomData(ctx context.Context, tx *sql.Tx, roomJID *jid.JID) (*mucmodel.Room, + error) { + room := &mucmodel.Room{} + // fetch room data + q := sq.Select("room_jid", "name", "description", "subject", "language", "locked", + "occupants_online"). + From("rooms"). + Where(sq.Eq{"room_jid": roomJID.String()}) + var onlineCnt int + var roomJIDStr string + err := q.RunWith(tx). + QueryRowContext(ctx). + Scan(&roomJIDStr, &room.Name, &room.Desc, &room.Subject, &room.Language, &room.Locked, + &onlineCnt) + switch err { + case nil: + rJID, err := jid.NewWithString(roomJIDStr, false) + if err != nil { + return nil, err + } + room.RoomJID = rJID + room.SetOccupantsOnlineCount(onlineCnt) + default: + return nil, err + } + return room, nil +} + +func fetchRoomConfig(ctx context.Context, tx *sql.Tx, room *mucmodel.Room, + roomJID *jid.JID) error { + rc := &mucmodel.RoomConfig{} + q := sq.Select("room_jid", "public", "persistent", "pwd_protected", "password", "open", + "moderated", "allow_invites", "max_occupants", "allow_subj_change", "non_anonymous", + "can_send_pm", "can_get_member_list"). + From("rooms_config"). + Where(sq.Eq{"room_jid": roomJID.String()}) + var dummy, sendPM, membList string + err := q.RunWith(tx). + QueryRowContext(ctx). + Scan(&dummy, &rc.Public, &rc.Persistent, &rc.PwdProtected, &rc.Password, &rc.Open, + &rc.Moderated, &rc.AllowInvites, &rc.MaxOccCnt, &rc.AllowSubjChange, &rc.NonAnonymous, + &sendPM, &membList) + switch err { + case nil: + err = rc.SetWhoCanSendPM(sendPM) + if err != nil { + return err + } + err = rc.SetWhoCanGetMemberList(membList) + if err != nil { + return err + } + default: + return err + } + room.Config = rc + return nil +} + +func fetchRoomUsers(ctx context.Context, tx *sql.Tx, room *mucmodel.Room, + roomJID *jid.JID) error { + res, err := sq.Select("room_jid", "user_jid", "occupant_jid"). + From("rooms_users"). + Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + for res.Next() { + var dummy, uJIDStr, oJIDStr string + if err := res.Scan(&dummy, &uJIDStr, &oJIDStr); err != nil { + return err + } + uJID, err := jid.NewWithString(uJIDStr, false) + if err != nil { + return err + } + oJID, err := jid.NewWithString(oJIDStr, false) + if err != nil { + return err + } + err = room.MapUserToOccupantJID(uJID, oJID) + if err != nil { + return err + } + } + return nil +} + +func fetchRoomInvites(ctx context.Context, tx *sql.Tx, room *mucmodel.Room, + roomJID *jid.JID) error { + resInv, err := sq.Select("room_jid", "user_jid"). + From("rooms_invites"). + Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).QueryContext(ctx) + if err != nil { + return err + } + for resInv.Next() { + var dummy, uJIDStr string + if err := resInv.Scan(&dummy, &uJIDStr); err != nil { + return err + } + uJID, err := jid.NewWithString(uJIDStr, false) + if err != nil { + return err + } + err = room.InviteUser(uJID) + if err != nil { + return err + } + } + return nil +} + +func (r *pgSQLRoom) DeleteRoom(ctx context.Context, roomJID *jid.JID) error { + return r.inTransaction(ctx, func(tx *sql.Tx) error { + _, err := sq.Delete("rooms").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("rooms_config").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("rooms_users").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + _, err = sq.Delete("rooms_invites").Where(sq.Eq{"room_jid": roomJID.String()}). + RunWith(tx).ExecContext(ctx) + if err != nil { + return err + } + return nil + }) +} + +func (r *pgSQLRoom) RoomExists(ctx context.Context, roomJID *jid.JID) (bool, error) { + q := sq.Select("COUNT(*)"). + From("rooms"). + Where(sq.Eq{"room_jid": roomJID.String()}) + + var count int + err := q.RunWith(r.db).QueryRowContext(ctx).Scan(&count) + switch err { + case nil: + return count > 0, nil + default: + return false, err + } +} diff --git a/storage/pgsql/room_test.go b/storage/pgsql/room_test.go new file mode 100644 index 000000000..7394e9e9a --- /dev/null +++ b/storage/pgsql/room_test.go @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPgSQLStorageInsertRoom(t *testing.T) { + room := getTestRoom() + s, mock := newRoomMock() + rc := room.Config + userJID := room.GetAllUserJIDs()[0] + occJID, _ := room.GetOccupantJID(&userJID) + invitedUser := room.GetAllInvitedUsers()[0] + + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO rooms (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(room.RoomJID.String(), room.Name, room.Desc, room.Subject, room.Language, + room.Locked, room.GetOccupantsOnlineCount()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO rooms_config (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(room.RoomJID.String(), rc.Public, rc.Persistent, rc.PwdProtected, + rc.Password, rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, + rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO rooms_invites (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(room.RoomJID.String(), invitedUser). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT INTO rooms_users (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(room.RoomJID.String(), userJID.String(), occJID.String()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := s.UpsertRoom(context.Background(), room) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO rooms (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). + WithArgs(room.RoomJID.String(), room.Name, room.Desc, room.Subject, room.Language, + room.Locked, room.GetOccupantsOnlineCount()). + WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.UpsertRoom(context.Background(), room) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, err, errMocked) +} + +func TestPgSQLStorageFetchRoom(t *testing.T) { + room := getTestRoom() + rc := room.Config + s, mock := newRoomMock() + roomColumns := []string{"room_jid", "name", "description", "subject", "language", "locked", + "occupants_online"} + rcColumns := []string{"room_jid", "public", "persistent", "pwd_protected", "password", "open", + "moderated", "allow_invites", "max_occupants", "allow_subj_change", "non_anonymous", + "can_send_pm", "can_get_member_list"} + usersColumns := []string{"room_jid", "user_jid", "occupant_jid"} + invitesColumns := []string{"room_jid", "user_jid"} + + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(roomColumns)) + mock.ExpectCommit() + + r, _ := s.FetchRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, r) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(roomColumns). + AddRow(room.RoomJID.String(), room.Name, room.Desc, room.Subject, room.Language, + room.Locked, room.GetOccupantsOnlineCount())) + mock.ExpectQuery("SELECT (.+) FROM rooms_config (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(rcColumns). + AddRow(room.RoomJID.String(), rc.Public, rc.Persistent, rc.PwdProtected, rc.Password, + rc.Open, rc.Moderated, rc.AllowInvites, rc.MaxOccCnt, rc.AllowSubjChange, + rc.NonAnonymous, rc.WhoCanSendPM(), rc.WhoCanGetMemberList())) + mock.ExpectQuery("SELECT (.+) FROM rooms_users (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(usersColumns). + AddRow(room.RoomJID.String(), room.GetAllUserJIDs()[0].String(), + room.GetAllOccupantJIDs()[0].String())) + mock.ExpectQuery("SELECT (.+) FROM rooms_invites (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(invitesColumns). + AddRow(room.RoomJID.String(), room.GetAllInvitedUsers()[0])) + mock.ExpectCommit() + r, err := s.FetchRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.NotNil(t, r) + assert.EqualValues(t, room, r) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectQuery("SELECT (.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()).WillReturnError(errMocked) + mock.ExpectRollback() + _, err = s.FetchRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestPgSQLStorageDeleteRoom(t *testing.T) { + room := getTestRoom() + s, mock := newRoomMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM rooms (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM rooms_config (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM rooms_users (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec("DELETE FROM rooms_invites (.+)"). + WithArgs(room.RoomJID.String()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + err := s.DeleteRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + s, mock = newRoomMock() + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM rooms (.+)"). + WithArgs(room.RoomJID.String()).WillReturnError(errMocked) + mock.ExpectRollback() + + err = s.DeleteRoom(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func TestPgSQLStorageRoomExists(t *testing.T) { + room := getTestRoom() + countCols := []string{"count"} + + s, mock := newRoomMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnRows(sqlmock.NewRows(countCols).AddRow(1)) + + ok, err := s.RoomExists(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + require.True(t, ok) + + s, mock = newRoomMock() + mock.ExpectQuery("SELECT COUNT(.+) FROM rooms (.+)"). + WithArgs(room.RoomJID.String()). + WillReturnError(errMocked) + _, err = s.RoomExists(context.Background(), room.RoomJID) + require.Nil(t, mock.ExpectationsWereMet()) + require.Equal(t, errMocked, err) +} + +func newRoomMock() (*pgSQLRoom, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLRoom{ + pgSQLStorage: s, + }, sqlMock +} + +func getTestRoom() *mucmodel.Room { + rc := mucmodel.RoomConfig{ + Public: true, + Persistent: true, + PwdProtected: false, + Open: true, + Moderated: false, + } + j, _ := jid.NewWithString("testroom@conference.jackal.im", true) + + r := &mucmodel.Room{ + Name: "testRoom", + RoomJID: j, + Desc: "Room for Testing", + Config: &rc, + Locked: false, + } + + oJID, _ := jid.NewWithString("testroom@conference.jackal.im/owner", true) + owner, _ := mucmodel.NewOccupant(oJID, oJID.ToBareJID()) + r.AddOccupant(owner) + r.InviteUser(oJID.ToBareJID()) + + return r +} diff --git a/storage/pgsql/roster.go b/storage/pgsql/roster.go index ea7d7ac2d..883dc79c6 100644 --- a/storage/pgsql/roster.go +++ b/storage/pgsql/roster.go @@ -6,28 +6,39 @@ package pgsql import ( + "context" "database/sql" "encoding/json" "strings" sq "github.com/Masterminds/squirrel" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -// InsertOrUpdateRosterItem inserts a new roster item entity into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) { +type pgSQLRoster struct { + *pgSQLStorage + pool *pool.BufferPool +} + +func newRoster(db *sql.DB) *pgSQLRoster { + return &pgSQLRoster{ + pgSQLStorage: newStorage(db), + } +} + +func (s *pgSQLRoster) UpsertRosterItem(ctx context.Context, ri *rostermodel.Item) (rostermodel.Version, error) { var ver rostermodel.Version - err := s.inTransaction(func(tx *sql.Tx) error { + err := s.inTransaction(ctx, func(tx *sql.Tx) error { q := sq.Insert("roster_versions"). Columns("username"). Values(ri.Username). Suffix("ON CONFLICT (username) DO UPDATE SET ver = roster_versions.ver + 1") - if _, err := q.RunWith(tx).Exec(); err != nil { + if _, err := q.RunWith(tx).ExecContext(ctx); err != nil { return err } groupsBytes, err := json.Marshal(ri.Groups) @@ -40,14 +51,14 @@ func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve Columns("username", "jid", "name", "subscription", "groups", "ask", "ver"). Values(ri.Username, ri.JID, ri.Name, ri.Subscription, groupsBytes, ri.Ask, verExpr). Suffix("ON CONFLICT (username, jid) DO UPDATE SET name = $3, subscription = $4, groups = $5, ask = $6, ver = roster_items.ver + 1") - _, err = q.RunWith(tx).Exec() + _, err = q.RunWith(tx).ExecContext(ctx) if err != nil { return err } // delete previous groups _, err = sq.Delete("roster_groups"). Where(sq.And{sq.Eq{"username": ri.Username}, sq.Eq{"jid": ri.JID}}). - RunWith(tx).Exec() + RunWith(tx).ExecContext(ctx) if err != nil { return err } @@ -56,13 +67,13 @@ func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve q = sq.Insert("roster_groups"). Columns("username", "jid", `"group"`, "created_at", "updated_at"). Values(ri.Username, ri.JID, group, nowExpr, nowExpr) - _, err := q.RunWith(tx).Exec() + _, err := q.RunWith(tx).ExecContext(ctx) if err != nil { return err } } // fetch new roster version - ver, err = fetchRosterVer(ri.Username, tx) + ver, err = fetchRosterVer(ctx, ri.Username, tx) return err }) if err != nil { @@ -71,36 +82,35 @@ func (s *Storage) InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Ve return ver, nil } -// DeleteRosterItem deletes a roster item entity from storage. -func (s *Storage) DeleteRosterItem(username, jid string) (rostermodel.Version, error) { +func (s *pgSQLRoster) DeleteRosterItem(ctx context.Context, username, jid string) (rostermodel.Version, error) { var ver rostermodel.Version - err := s.inTransaction(func(tx *sql.Tx) error { + err := s.inTransaction(ctx, func(tx *sql.Tx) error { q := sq.Insert("roster_versions"). Columns("username"). Values(username). Suffix("ON CONFLICT (username) DO UPDATE SET ver = roster_versions.ver + 1, last_deletion_ver = roster_versions.ver") - if _, err := q.RunWith(tx).Exec(); err != nil { + if _, err := q.RunWith(tx).ExecContext(ctx); err != nil { return err } // delete groups _, err := sq.Delete("roster_groups"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"jid": jid}}). - RunWith(tx).Exec() + RunWith(tx).ExecContext(ctx) if err != nil { return err } // delete items _, err = sq.Delete("roster_items"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"jid": jid}}). - RunWith(tx).Exec() + RunWith(tx).ExecContext(ctx) if err != nil { return err } // fetch new roster version - ver, err = fetchRosterVer(username, tx) + ver, err = fetchRosterVer(ctx, username, tx) return err }) if err != nil { @@ -109,65 +119,60 @@ func (s *Storage) DeleteRosterItem(username, jid string) (rostermodel.Version, e return ver, nil } -// FetchRosterItems retrieves from storage all roster item entities -// associated to a given user. -func (s *Storage) FetchRosterItems(username string) ([]rostermodel.Item, rostermodel.Version, error) { +func (s *pgSQLRoster) FetchRosterItems(ctx context.Context, username string) ([]rostermodel.Item, rostermodel.Version, error) { q := sq.Select("username", "jid", "name", "subscription", "groups", "ask", "ver"). From("roster_items"). Where(sq.Eq{"username": username}). OrderBy("created_at DESC") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, rostermodel.Version{}, err } defer func() { _ = rows.Close() }() - items, err := s.scanRosterItemEntities(rows) + items, err := scanRosterItemEntities(rows) if err != nil { return nil, rostermodel.Version{}, err } - ver, err := fetchRosterVer(username, s.db) + ver, err := fetchRosterVer(ctx, username, s.db) if err != nil { return nil, rostermodel.Version{}, err } return items, ver, nil } -// FetchRosterItemsInGroups retrieves from storage all roster item entities -// associated to a given user and a set of groups. -func (s *Storage) FetchRosterItemsInGroups(username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { +func (s *pgSQLRoster) FetchRosterItemsInGroups(ctx context.Context, username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { q := sq.Select("ris.username", "ris.jid", "ris.name", "ris.subscription", "ris.groups", "ris.ask", "ris.ver"). From("roster_items ris"). - LeftJoin("roster_groups g on ris.username = g.username"). + LeftJoin("roster_groups g ON ris.username = g.username"). Where(sq.And{sq.Eq{"ris.username": username}, sq.Eq{"g.group": groups}}). OrderBy("ris.created_at DESC") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, rostermodel.Version{}, err } defer func() { _ = rows.Close() }() - items, err := s.scanRosterItemEntities(rows) + items, err := scanRosterItemEntities(rows) if err != nil { return nil, rostermodel.Version{}, err } - ver, err := fetchRosterVer(username, s.db) + ver, err := fetchRosterVer(ctx, username, s.db) if err != nil { return nil, rostermodel.Version{}, err } return items, ver, nil } -// FetchRosterItem retrieves from storage a roster item entity. -func (s *Storage) FetchRosterItem(username, jid string) (*rostermodel.Item, error) { +func (s *pgSQLRoster) FetchRosterItem(ctx context.Context, username, jid string) (*rostermodel.Item, error) { q := sq.Select("username", "jid", "name", "subscription", "groups", "ask", "ver"). From("roster_items"). Where(sq.And{sq.Eq{"username": username}, sq.Eq{"jid": jid}}) var ri rostermodel.Item - err := s.scanRosterItemEntity(&ri, q.RunWith(s.db).QueryRow()) + err := scanRosterItemEntity(&ri, q.RunWith(s.db).QueryRowContext(ctx)) switch err { case nil: return &ri, nil @@ -178,9 +183,7 @@ func (s *Storage) FetchRosterItem(username, jid string) (*rostermodel.Item, erro } } -// InsertOrUpdateRosterNotification inserts a new roster notification entity -// into storage, or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error { +func (s *pgSQLRoster) UpsertRosterNotification(ctx context.Context, rn *rostermodel.Notification) error { presenceXML := rn.Presence.String() q := sq.Insert("roster_notifications"). @@ -188,20 +191,18 @@ func (s *Storage) InsertOrUpdateRosterNotification(rn *rostermodel.Notification) Values(rn.Contact, rn.JID, presenceXML). Suffix("ON CONFLICT (contact, jid) DO UPDATE SET elements = $4", presenceXML) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -// FetchRosterNotifications retrieves from storage all roster notifications -// associated to a given user. -func (s *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { +func (s *pgSQLRoster) FetchRosterNotifications(ctx context.Context, contact string) ([]rostermodel.Notification, error) { q := sq.Select("contact", "jid", "elements"). From("roster_notifications"). Where(sq.Eq{"contact": contact}). OrderBy("created_at") - rows, err := q.RunWith(s.db).Query() + rows, err := q.RunWith(s.db).QueryContext(ctx) if err != nil { return nil, err } @@ -210,7 +211,7 @@ func (s *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notifi var ret []rostermodel.Notification for rows.Next() { var rn rostermodel.Notification - if err := s.scanRosterNotificationEntity(&rn, rows); err != nil { + if err := scanRosterNotificationEntity(&rn, rows); err != nil { return nil, err } ret = append(ret, rn) @@ -218,14 +219,13 @@ func (s *Storage) FetchRosterNotifications(contact string) ([]rostermodel.Notifi return ret, nil } -// FetchRosterNotification retrieves from storage a roster notification entity. -func (s *Storage) FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) { +func (s *pgSQLRoster) FetchRosterNotification(ctx context.Context, contact string, jid string) (*rostermodel.Notification, error) { q := sq.Select("contact", "jid", "elements"). From("roster_notifications"). Where(sq.And{sq.Eq{"contact": contact}, sq.Eq{"jid": jid}}) var rn rostermodel.Notification - err := s.scanRosterNotificationEntity(&rn, q.RunWith(s.db).QueryRow()) + err := scanRosterNotificationEntity(&rn, q.RunWith(s.db).QueryRowContext(ctx)) switch err { case nil: return &rn, nil @@ -236,14 +236,36 @@ func (s *Storage) FetchRosterNotification(contact string, jid string) (*rostermo } } -// DeleteRosterNotification deletes a roster notification entity from storage. -func (s *Storage) DeleteRosterNotification(contact, jid string) error { +func (s *pgSQLRoster) DeleteRosterNotification(ctx context.Context, contact, jid string) error { q := sq.Delete("roster_notifications").Where(sq.And{sq.Eq{"contact": contact}, sq.Eq{"jid": jid}}) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -func (s *Storage) scanRosterNotificationEntity(rn *rostermodel.Notification, scanner rowScanner) error { +func (s *pgSQLRoster) FetchRosterGroups(ctx context.Context, username string) ([]string, error) { + q := sq.Select("`group`"). + From("roster_groups"). + Where(sq.Eq{"username": username}). + GroupBy("`group`") + + rows, err := q.RunWith(s.db).QueryContext(ctx) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var groups []string + for rows.Next() { + var group string + if err := rows.Scan(&group); err != nil { + return nil, err + } + groups = append(groups, group) + } + return groups, nil +} + +func scanRosterNotificationEntity(rn *rostermodel.Notification, scanner rowScanner) error { var presenceXML string if err := scanner.Scan(&rn.Contact, &rn.JID, &presenceXML); err != nil { return err @@ -259,7 +281,7 @@ func (s *Storage) scanRosterNotificationEntity(rn *rostermodel.Notification, sca return nil } -func (s *Storage) scanRosterItemEntity(ri *rostermodel.Item, scanner rowScanner) error { +func scanRosterItemEntity(ri *rostermodel.Item, scanner rowScanner) error { var groupsBytes string if err := scanner.Scan(&ri.Username, &ri.JID, &ri.Name, &ri.Subscription, &groupsBytes, &ri.Ask, &ri.Ver); err != nil { return err @@ -272,11 +294,11 @@ func (s *Storage) scanRosterItemEntity(ri *rostermodel.Item, scanner rowScanner) return nil } -func (s *Storage) scanRosterItemEntities(scanner rowsScanner) ([]rostermodel.Item, error) { +func scanRosterItemEntities(scanner rowsScanner) ([]rostermodel.Item, error) { var ret []rostermodel.Item for scanner.Next() { var ri rostermodel.Item - if err := s.scanRosterItemEntity(&ri, scanner); err != nil { + if err := scanRosterItemEntity(&ri, scanner); err != nil { return nil, err } ret = append(ret, ri) @@ -284,13 +306,13 @@ func (s *Storage) scanRosterItemEntities(scanner rowsScanner) ([]rostermodel.Ite return ret, nil } -func fetchRosterVer(username string, runner sq.BaseRunner) (rostermodel.Version, error) { +func fetchRosterVer(ctx context.Context, username string, runner sq.BaseRunner) (rostermodel.Version, error) { q := sq.Select("COALESCE(MAX(ver), 0)", "COALESCE(MAX(last_deletion_ver), 0)"). From("roster_versions"). Where(sq.Eq{"username": username}) var ver rostermodel.Version - row := q.RunWith(runner).QueryRow() + row := q.RunWith(runner).QueryRowContext(ctx) err := row.Scan(&ver.Ver, &ver.DeletionVer) switch err { case nil: diff --git a/storage/pgsql/roster_test.go b/storage/pgsql/roster_test.go index 73b5988dc..0b7e63e62 100644 --- a/storage/pgsql/roster_test.go +++ b/storage/pgsql/roster_test.go @@ -6,12 +6,13 @@ package pgsql import ( + "context" "database/sql/driver" "encoding/json" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/ortuman/jackal/model/rostermodel" + rostermodel "github.com/ortuman/jackal/model/roster" "github.com/ortuman/jackal/xmpp" "github.com/stretchr/testify/require" ) @@ -39,7 +40,7 @@ func TestInsertRosterItem(t *testing.T) { ri.Username, } - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectBegin() @@ -69,13 +70,13 @@ func TestInsertRosterItem(t *testing.T) { mock.ExpectCommit() - _, err := s.InsertOrUpdateRosterItem(&ri) + _, err := s.UpsertRosterItem(context.Background(), &ri) require.Nil(t, err) require.Nil(t, mock.ExpectationsWereMet()) } func TestDeleteRosterItem(t *testing.T) { - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectBegin() mock.ExpectExec("INSERT INTO roster_versions (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs("user").WillReturnResult(sqlmock.NewResult(0, 1)) @@ -88,17 +89,17 @@ func TestDeleteRosterItem(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"ver", "deletionVer"}).AddRow(1, 0)) mock.ExpectCommit() - _, err := s.DeleteRosterItem("user", "contact") + _, err := s.DeleteRosterItem(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectBegin() mock.ExpectExec("INSERT INTO roster_versions (.+)"). WithArgs("user").WillReturnError(errGeneric) mock.ExpectRollback() - _, err = s.DeleteRosterItem("user", "contact") + _, err = s.DeleteRosterItem(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } @@ -106,7 +107,7 @@ func TestDeleteRosterItem(t *testing.T) { func TestFetchRosterItems(t *testing.T) { var riColumns = []string{"user", "contact", "name", "subscription", "`groups`", "ask", "ver"} - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(riColumns).AddRow("ortuman", "romeo", "Romeo", "both", "", false, 0)) @@ -114,57 +115,57 @@ func TestFetchRosterItems(t *testing.T) { WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows([]string{"ver", "deletionVer"}).AddRow(0, 0)) - rosterItems, _, err := s.FetchRosterItems("ortuman") + rosterItems, _, err := s.FetchRosterItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 1, len(rosterItems)) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman"). WillReturnError(errGeneric) - _, _, err = s.FetchRosterItems("ortuman") + _, _, err = s.FetchRosterItems(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman", "romeo"). WillReturnRows(sqlmock.NewRows(riColumns).AddRow("ortuman", "romeo", "Romeo", "both", "", false, 0)) - _, err = s.FetchRosterItem("ortuman", "romeo") + _, err = s.FetchRosterItem(context.Background(), "ortuman", "romeo") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman", "romeo"). WillReturnRows(sqlmock.NewRows(riColumns)) - ri, _ := s.FetchRosterItem("ortuman", "romeo") + ri, _ := s.FetchRosterItem(context.Background(), "ortuman", "romeo") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, ri) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_items (.+)"). WithArgs("ortuman", "romeo"). WillReturnError(errGeneric) - _, err = s.FetchRosterItem("ortuman", "romeo") + _, err = s.FetchRosterItem(context.Background(), "ortuman", "romeo") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) var riColumns2 = []string{"ris.user", "ris.contact", "ris.name", "ris.subscription", "ris.groups", "ris.ask", "ris.ver"} - s, mock = NewMock() - mock.ExpectQuery("SELECT (.+) FROM roster_items ris LEFT JOIN roster_groups g on ris.username = g.username (.+)"). + s, mock = newRosterMock() + mock.ExpectQuery("SELECT (.+) FROM roster_items ris LEFT JOIN roster_groups g ON ris.username = g.username (.+)"). WithArgs("ortuman", "Family"). WillReturnRows(sqlmock.NewRows(riColumns2).AddRow("ortuman", "romeo", "Romeo", "both", `["Family"]`, false, 0)) mock.ExpectQuery("SELECT (.+) FROM roster_versions (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows([]string{"ver", "deletionVer"}).AddRow(0, 0)) - _, _, err = s.FetchRosterItemsInGroups("ortuman", []string{"Family"}) + _, _, err = s.FetchRosterItemsInGroups(context.Background(), "ortuman", []string{"Family"}) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) } @@ -183,39 +184,39 @@ func TestInsertRosterNotification(t *testing.T) { presenceXML, presenceXML, } - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectExec("INSERT INTO roster_notifications (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs(args...). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdateRosterNotification(&rn) + err := s.UpsertRosterNotification(context.Background(), &rn) require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectExec("INSERT INTO roster_notifications (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs(args...). WillReturnError(errGeneric) - err = s.InsertOrUpdateRosterNotification(&rn) + err = s.UpsertRosterNotification(context.Background(), &rn) require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } func TestDeleteRosterNotification(t *testing.T) { - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectExec("DELETE FROM roster_notifications (.+)"). WithArgs("user", "contact").WillReturnResult(sqlmock.NewResult(0, 1)) - err := s.DeleteRosterNotification("user", "contact") + err := s.DeleteRosterNotification(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectExec("DELETE FROM roster_notifications (.+)"). WithArgs("user", "contact").WillReturnError(errGeneric) - err = s.DeleteRosterNotification("user", "contact") + err = s.DeleteRosterNotification(context.Background(), "user", "contact") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) } @@ -223,41 +224,77 @@ func TestDeleteRosterNotification(t *testing.T) { func TestFetchRosterNotifications(t *testing.T) { var rnColumns = []string{"user", "contact", "elements"} - s, mock := NewMock() + s, mock := newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(rnColumns).AddRow("romeo", "contact", "8")) - rosterNotifications, err := s.FetchRosterNotifications("ortuman") + rosterNotifications, err := s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 1, len(rosterNotifications)) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(rnColumns)) - rosterNotifications, err = s.FetchRosterNotifications("ortuman") + rosterNotifications, err = s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Equal(t, 0, len(rosterNotifications)) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnError(errGeneric) - _, err = s.FetchRosterNotifications("ortuman") + _, err = s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Equal(t, errGeneric, err) - s, mock = NewMock() + s, mock = newRosterMock() mock.ExpectQuery("SELECT (.+) FROM roster_notifications (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(rnColumns).AddRow("romeo", "contact", "8")) - _, err = s.FetchRosterNotifications("ortuman") + _, err = s.FetchRosterNotifications(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.NotNil(t, err) } + +func TestStorageFetchRosterGroups(t *testing.T) { + s, mock := newRosterMock() + mock.ExpectQuery("SELECT `group` FROM roster_groups WHERE username = (.+) GROUP BY (.+)"). + WithArgs("ortuman"). + WillReturnRows(sqlmock.NewRows([]string{"group"}). + AddRow("Contacts"). + AddRow("News")) + + groups, err := s.FetchRosterGroups(context.Background(), "ortuman") + require.Nil(t, mock.ExpectationsWereMet()) + require.Nil(t, err) + + require.Equal(t, 2, len(groups)) + require.Equal(t, "Contacts", groups[0]) + require.Equal(t, "News", groups[1]) + + s, mock = newRosterMock() + mock.ExpectQuery("SELECT `group` FROM roster_groups WHERE username = (.+) GROUP BY (.+)"). + WithArgs("ortuman"). + WillReturnError(errGeneric) + + groups, err = s.FetchRosterGroups(context.Background(), "ortuman") + require.Nil(t, mock.ExpectationsWereMet()) + + require.Nil(t, groups) + require.NotNil(t, err) + require.Equal(t, errGeneric, err) +} + +func newRosterMock() (*pgSQLRoster, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLRoster{ + pgSQLStorage: s, + }, sqlMock +} diff --git a/storage/pgsql/sql.go b/storage/pgsql/sql.go deleted file mode 100644 index 6e70ab506..000000000 --- a/storage/pgsql/sql.go +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package pgsql - -import ( - "context" - "database/sql" - "fmt" - "time" - - sq "github.com/Masterminds/squirrel" - _ "github.com/lib/pq" // PostgreSQL driver - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/pool" -) - -// pingInterval defines how often to check the connection -var pingInterval = 15 * time.Second - -// pingTimeout defines how long to wait for pong from server -var pingTimeout = 10 * time.Second - -var ( - nowExpr = sq.Expr("NOW()") -) - -type rowScanner interface { - Scan(...interface{}) error -} - -type rowsScanner interface { - rowScanner - Next() bool -} - -// Storage represents a SQL storage sub system. -type Storage struct { - db *sql.DB - pool *pool.BufferPool - cancelPing context.CancelFunc -} - -// New instantiates a PostgreSQL storage instance. -func New(c *Config) *Storage { - var err error - - sq.StatementBuilder = sq.StatementBuilder.PlaceholderFormat(sq.Dollar) - - s := &Storage{ - pool: pool.NewBufferPool(), - } - dsn := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s", c.User, c.Password, c.Host, c.Database, c.SSLMode) - - s.db, err = sql.Open("postgres", dsn) - if err != nil { - log.Fatalf("%v", err) - } - - s.db.SetMaxOpenConns(c.PoolSize) // set max opened connection count - - s.ping(context.Background()) - - ctx, cancel := context.WithCancel(context.Background()) - s.cancelPing = cancel - go s.pingLoop(ctx) - - return s -} - -// IsClusterCompatible returns whether or not the underlying storage subsystem can be used in cluster mode. -func (s *Storage) IsClusterCompatible() bool { return true } - -// Close shuts down SQL storage sub system. -func (s *Storage) Close() error { - s.cancelPing() // Stop pinging the server - - return s.db.Close() -} - -// pingLoop periodically pings the server to check connection status -func (s *Storage) pingLoop(ctx context.Context) { - tick := time.NewTicker(pingInterval) - defer tick.Stop() - - for { - select { - case <-tick.C: - s.ping(ctx) - case <-ctx.Done(): - return - } - } -} - -// ping sends a ping request to the server and outputs any error to log -func (s *Storage) ping(ctx context.Context) { - pingCtx, cancel := context.WithDeadline(ctx, time.Now().Add(pingTimeout)) - defer cancel() - - err := s.db.PingContext(pingCtx) - - if err != nil { - log.Error(err) - } -} - -func (s *Storage) inTransaction(f func(tx *sql.Tx) error) error { - tx, txErr := s.db.Begin() - - if txErr != nil { - return txErr - } - - if err := f(tx); err != nil { - tx.Rollback() - return err - } - return tx.Commit() -} diff --git a/storage/pgsql/sql_test.go b/storage/pgsql/sql_test.go deleted file mode 100644 index 82bd1d491..000000000 --- a/storage/pgsql/sql_test.go +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package pgsql - -import ( - "errors" - - sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/pool" -) - -var ( - errGeneric = errors.New("pgsql: generic storage error") -) - -// NewMock returns a mocked SQL storage instance. -func NewMock() (*Storage, sqlmock.Sqlmock) { - var err error - var sqlMock sqlmock.Sqlmock - - s := &Storage{ - pool: pool.NewBufferPool(), - } - - s.db, sqlMock, err = sqlmock.New() - - if err != nil { - log.Fatalf("%v", err) - } - - return s, sqlMock -} diff --git a/storage/pgsql/storage.go b/storage/pgsql/storage.go new file mode 100644 index 000000000..be6f29788 --- /dev/null +++ b/storage/pgsql/storage.go @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + "context" + "database/sql" + "errors" + + sq "github.com/Masterminds/squirrel" +) + +var ( + nowExpr = sq.Expr("NOW()") +) + +type rowScanner interface { + Scan(...interface{}) error +} + +type rowsScanner interface { + rowScanner + Next() bool +} + +// pgSQLStorage represents a SQL storage base sub system. +type pgSQLStorage struct { + db *sql.DB +} + +var ( + errMocked = errors.New("pgsql: storage error") +) + +// newStorage instantiates a PostgreSQL base storage instance. +func newStorage(db *sql.DB) *pgSQLStorage { + return &pgSQLStorage{db: db} +} + +func (s *pgSQLStorage) inTransaction(ctx context.Context, f func(tx *sql.Tx) error) error { + tx, txErr := s.db.BeginTx(ctx, nil) + if txErr != nil { + return txErr + } + if err := f(tx); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} diff --git a/storage/pgsql/storage_test.go b/storage/pgsql/storage_test.go new file mode 100644 index 000000000..77c5b8d31 --- /dev/null +++ b/storage/pgsql/storage_test.go @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package pgsql + +import ( + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/ortuman/jackal/log" +) + +// newStorageMock returns a mocked MySQL storage instance. +func newStorageMock() (*pgSQLStorage, sqlmock.Sqlmock) { + db, sqlMock, err := sqlmock.New() + if err != nil { + log.Fatalf("%v", err) + } + return &pgSQLStorage{db: db}, sqlMock +} diff --git a/storage/pgsql/user.go b/storage/pgsql/user.go index 723cbe29f..c4b6ee912 100644 --- a/storage/pgsql/user.go +++ b/storage/pgsql/user.go @@ -6,45 +6,60 @@ package pgsql import ( + "context" "database/sql" "strings" "time" sq "github.com/Masterminds/squirrel" "github.com/ortuman/jackal/model" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" ) -// InsertOrUpdateUser inserts a new user entity into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateUser(u *model.User) error { +type pgSQLUser struct { + *pgSQLStorage + pool *pool.BufferPool +} + +func newUser(db *sql.DB) *pgSQLUser { + return &pgSQLUser{ + pgSQLStorage: newStorage(db), + pool: pool.NewBufferPool(), + } +} + +// UpsertUser inserts a new user entity into storage, or updates it in case it's been previously inserted. +func (u *pgSQLUser) UpsertUser(ctx context.Context, usr *model.User) error { var presenceXML string - if u.LastPresence != nil { - buf := s.pool.Get() - u.LastPresence.ToXML(buf, true) + if usr.LastPresence != nil { + buf := u.pool.Get() + if err := usr.LastPresence.ToXML(buf, true); err != nil { + return err + } presenceXML = buf.String() - s.pool.Put(buf) + u.pool.Put(buf) } q := sq.Insert("users") if len(presenceXML) > 0 { q = q.Columns("username", "password", "last_presence", "last_presence_at"). - Values(u.Username, u.Password, presenceXML, nowExpr). + Values(usr.Username, usr.Password, presenceXML, nowExpr). Suffix("ON CONFLICT (username) DO UPDATE SET password = $2, last_presence = $3, last_presence_at = NOW()") } else { q = q.Columns("username", "password"). - Values(u.Username, u.Password). + Values(usr.Username, usr.Password). Suffix("ON CONFLICT (username) DO UPDATE SET password = $2") } - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(u.db).ExecContext(ctx) return err } // FetchUser retrieves from storage a user entity. -func (s *Storage) FetchUser(username string) (*model.User, error) { +func (u *pgSQLUser) FetchUser(ctx context.Context, username string) (*model.User, error) { q := sq.Select("username", "password", "last_presence", "last_presence_at"). From("users"). Where(sq.Eq{"username": username}) @@ -53,7 +68,7 @@ func (s *Storage) FetchUser(username string) (*model.User, error) { var presenceAt time.Time var usr model.User - err := q.RunWith(s.db).QueryRow().Scan(&usr.Username, &usr.Password, &presenceXML, &presenceAt) + err := q.RunWith(u.db).QueryRowContext(ctx).Scan(&usr.Username, &usr.Password, &presenceXML, &presenceAt) switch err { case nil: if len(presenceXML) > 0 { @@ -76,30 +91,30 @@ func (s *Storage) FetchUser(username string) (*model.User, error) { } // DeleteUser deletes a user entity from storage. -func (s *Storage) DeleteUser(username string) error { - return s.inTransaction(func(tx *sql.Tx) error { +func (u *pgSQLUser) DeleteUser(ctx context.Context, username string) error { + return u.inTransaction(ctx, func(tx *sql.Tx) error { var err error - _, err = sq.Delete("offline_messages").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("offline_messages").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("roster_items").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("roster_items").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("roster_versions").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("roster_versions").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("private_storage").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("private_storage").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("vcards").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("vcards").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } - _, err = sq.Delete("users").Where(sq.Eq{"username": username}).RunWith(tx).Exec() + _, err = sq.Delete("users").Where(sq.Eq{"username": username}).RunWith(tx).ExecContext(ctx) if err != nil { return err } @@ -108,10 +123,11 @@ func (s *Storage) DeleteUser(username string) error { } // UserExists returns whether or not a user exists within storage. -func (s *Storage) UserExists(username string) (bool, error) { - q := sq.Select("COUNT(*)").From("users").Where(sq.Eq{"username": username}) +func (u *pgSQLUser) UserExists(ctx context.Context, username string) (bool, error) { var count int - err := q.RunWith(s.db).QueryRow().Scan(&count) + + q := sq.Select("COUNT(*)").From("users").Where(sq.Eq{"username": username}) + err := q.RunWith(u.db).QueryRowContext(ctx).Scan(&count) switch err { case nil: return count > 0, nil diff --git a/storage/pgsql/user_test.go b/storage/pgsql/user_test.go index 9d01650a8..bd663cf71 100644 --- a/storage/pgsql/user_test.go +++ b/storage/pgsql/user_test.go @@ -6,11 +6,13 @@ package pgsql import ( + "context" "testing" "time" sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/ortuman/jackal/model" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" "github.com/stretchr/testify/require" @@ -23,27 +25,27 @@ func TestInsertUser(t *testing.T) { user := model.User{Username: "ortuman", Password: "1234", LastPresence: p} - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectExec("INSERT INTO users (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs(user.Username, user.Password, user.LastPresence.String()). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdateUser(&user) + err := s.UpsertUser(context.Background(), &user) require.Nil(t, err) require.Nil(t, mock.ExpectationsWereMet()) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectExec("INSERT INTO users (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs(user.Username, user.Password, user.LastPresence.String()). - WillReturnError(errGeneric) + WillReturnError(errMocked) - err = s.InsertOrUpdateUser(&user) - require.Equal(t, errGeneric, err) + err = s.UpsertUser(context.Background(), &user) + require.Equal(t, errMocked, err) require.Nil(t, mock.ExpectationsWereMet()) } func TestDeleteUser(t *testing.T) { - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectBegin() mock.ExpectExec("DELETE FROM offline_messages (.+)"). WithArgs("ortuman").WillReturnResult(sqlmock.NewResult(0, 1)) @@ -59,19 +61,19 @@ func TestDeleteUser(t *testing.T) { WithArgs("ortuman").WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() - err := s.DeleteUser("ortuman") + err := s.DeleteUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectBegin() mock.ExpectExec("DELETE FROM offline_messages (.+)"). - WithArgs("ortuman").WillReturnError(errGeneric) + WithArgs("ortuman").WillReturnError(errMocked) mock.ExpectRollback() - err = s.DeleteUser("ortuman") + err = s.DeleteUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errGeneric, err) + require.Equal(t, errMocked, err) } func TestFetchUser(t *testing.T) { @@ -81,49 +83,57 @@ func TestFetchUser(t *testing.T) { var userColumns = []string{"username", "password", "last_presence", "last_presence_at"} - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectQuery("SELECT (.+) FROM users (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(userColumns)) - usr, _ := s.FetchUser("ortuman") + usr, _ := s.FetchUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, usr) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectQuery("SELECT (.+) FROM users (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(userColumns).AddRow("ortuman", "1234", p.String(), time.Now())) - _, err := s.FetchUser("ortuman") + _, err := s.FetchUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectQuery("SELECT (.+) FROM users (.+)"). - WithArgs("ortuman").WillReturnError(errGeneric) - _, err = s.FetchUser("ortuman") + WithArgs("ortuman").WillReturnError(errMocked) + _, err = s.FetchUser(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errGeneric, err) + require.Equal(t, errMocked, err) } func TestUserExists(t *testing.T) { countColums := []string{"count"} - s, mock := NewMock() + s, mock := newUserMock() mock.ExpectQuery("SELECT COUNT(.+) FROM users (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(countColums).AddRow(1)) - ok, err := s.UserExists("ortuman") + ok, err := s.UserExists(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.True(t, ok) - s, mock = NewMock() + s, mock = newUserMock() mock.ExpectQuery("SELECT COUNT(.+) FROM users (.+)"). WithArgs("romeo"). - WillReturnError(errGeneric) - _, err = s.UserExists("romeo") + WillReturnError(errMocked) + _, err = s.UserExists(context.Background(), "romeo") require.Nil(t, mock.ExpectationsWereMet()) - require.Equal(t, errGeneric, err) + require.Equal(t, errMocked, err) +} + +func newUserMock() (*pgSQLUser, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLUser{ + pgSQLStorage: s, + pool: pool.NewBufferPool(), + }, sqlMock } diff --git a/storage/pgsql/vcard.go b/storage/pgsql/vcard.go index 5999ee57d..45c6436f0 100644 --- a/storage/pgsql/vcard.go +++ b/storage/pgsql/vcard.go @@ -6,6 +6,7 @@ package pgsql import ( + "context" "database/sql" "strings" @@ -13,28 +14,36 @@ import ( "github.com/ortuman/jackal/xmpp" ) -// InsertOrUpdateVCard inserts a new vCard element into storage, -// or updates it in case it's been previously inserted. -func (s *Storage) InsertOrUpdateVCard(vCard xmpp.XElement, username string) error { +type pgSQLVCard struct { + *pgSQLStorage +} + +func newVCard(db *sql.DB) *pgSQLVCard { + return &pgSQLVCard{ + pgSQLStorage: newStorage(db), + } +} + +// UpsertVCard inserts a new vCard element into storage, or updates it in case it's been previously inserted. +func (s *pgSQLVCard) UpsertVCard(ctx context.Context, vCard xmpp.XElement, username string) error { rawXML := vCard.String() q := sq.Insert("vcards"). Columns("username", "vcard"). Values(username, rawXML). - Suffix("ON CONFLICT (username) DO UPDATE SET vcard = ?", rawXML) + Suffix("ON CONFLICT (username) DO UPDATE SET vcard = $3", rawXML) - _, err := q.RunWith(s.db).Exec() + _, err := q.RunWith(s.db).ExecContext(ctx) return err } -// FetchVCard retrieves from storage a vCard element associated -// to a given user. -func (s *Storage) FetchVCard(username string) (xmpp.XElement, error) { +// FetchVCard retrieves from storage a vCard element associated to a given user. +func (s *pgSQLVCard) FetchVCard(ctx context.Context, username string) (xmpp.XElement, error) { q := sq.Select("vcard").From("vcards").Where(sq.Eq{"username": username}) var vCard string - err := q.RunWith(s.db).QueryRow().Scan(&vCard) + err := q.RunWith(s.db).QueryRowContext(ctx).Scan(&vCard) switch err { case nil: diff --git a/storage/pgsql/vcard_test.go b/storage/pgsql/vcard_test.go index 23d1fa40e..3fd0eb8ba 100644 --- a/storage/pgsql/vcard_test.go +++ b/storage/pgsql/vcard_test.go @@ -6,6 +6,7 @@ package pgsql import ( + "context" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" @@ -17,22 +18,22 @@ func TestInsertVCard(t *testing.T) { vCard := xmpp.NewElementName("vCard") rawXML := vCard.String() - s, mock := NewMock() + s, mock := newVCardMock() mock.ExpectExec("INSERT INTO vcards (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs("ortuman", rawXML, rawXML). WillReturnResult(sqlmock.NewResult(1, 1)) - err := s.InsertOrUpdateVCard(vCard, "ortuman") + err := s.UpsertVCard(context.Background(), vCard, "ortuman") require.Nil(t, err) require.NotNil(t, vCard) require.Nil(t, mock.ExpectationsWereMet()) - s, mock = NewMock() + s, mock = newVCardMock() mock.ExpectExec("INSERT INTO vcards (.+) ON CONFLICT (.+) DO UPDATE SET (.+)"). WithArgs("ortuman", rawXML, rawXML). WillReturnError(errGeneric) - err = s.InsertOrUpdateVCard(vCard, "ortuman") + err = s.UpsertVCard(context.Background(), vCard, "ortuman") require.Equal(t, errGeneric, err) require.Nil(t, mock.ExpectationsWereMet()) } @@ -40,32 +41,39 @@ func TestInsertVCard(t *testing.T) { func TestFetchVCard(t *testing.T) { var vCardColumns = []string{"vcard"} - s, mock := NewMock() + s, mock := newVCardMock() mock.ExpectQuery("SELECT (.+) FROM vcards (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(vCardColumns).AddRow("Miguel Ɓngel")) - vCard, err := s.FetchVCard("ortuman") + vCard, err := s.FetchVCard(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.NotNil(t, vCard) - s, mock = NewMock() + s, mock = newVCardMock() mock.ExpectQuery("SELECT (.+) FROM vcards (.+)"). WithArgs("ortuman"). WillReturnRows(sqlmock.NewRows(vCardColumns)) - vCard, err = s.FetchVCard("ortuman") + vCard, err = s.FetchVCard(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, err) require.Nil(t, vCard) - s, mock = NewMock() + s, mock = newVCardMock() mock.ExpectQuery("SELECT (.+) FROM vcards (.+)"). WithArgs("ortuman"). WillReturnError(errGeneric) - vCard, _ = s.FetchVCard("ortuman") + vCard, _ = s.FetchVCard(context.Background(), "ortuman") require.Nil(t, mock.ExpectationsWereMet()) require.Nil(t, vCard) } + +func newVCardMock() (*pgSQLVCard, sqlmock.Sqlmock) { + s, sqlMock := newStorageMock() + return &pgSQLVCard{ + pgSQLStorage: s, + }, sqlMock +} diff --git a/storage/private.go b/storage/private.go deleted file mode 100644 index c36f10885..000000000 --- a/storage/private.go +++ /dev/null @@ -1,20 +0,0 @@ -package storage - -import "github.com/ortuman/jackal/xmpp" - -// privateStorage defines operations for private storage -type privateStorage interface { - FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) - InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error -} - -// FetchPrivateXML retrieves from storage a private element. -func FetchPrivateXML(namespace string, username string) ([]xmpp.XElement, error) { - return instance().FetchPrivateXML(namespace, username) -} - -// InsertOrUpdatePrivateXML inserts a new private element into storage, -// or updates it in case it's been previously inserted. -func InsertOrUpdatePrivateXML(privateXML []xmpp.XElement, namespace string, username string) error { - return instance().InsertOrUpdatePrivateXML(privateXML, namespace, username) -} diff --git a/storage/repository/block_list.go b/storage/repository/block_list.go new file mode 100644 index 000000000..d55b664ad --- /dev/null +++ b/storage/repository/block_list.go @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + "github.com/ortuman/jackal/model" +) + +// BlockList defines storage operations for user's block list +type BlockList interface { + // InsertBlockListItem inserts a block list item entity into storage if not previously inserted. + InsertBlockListItem(ctx context.Context, item *model.BlockListItem) error + + // DeleteBlockListItem deletes a block list item entity from storage. + DeleteBlockListItem(ctx context.Context, item *model.BlockListItem) error + + // FetchBlockListItems retrieves from storage all block list item entities associated to a given user. + FetchBlockListItems(ctx context.Context, username string) ([]model.BlockListItem, error) +} diff --git a/storage/repository/container.go b/storage/repository/container.go new file mode 100644 index 000000000..5ee4cf9f0 --- /dev/null +++ b/storage/repository/container.go @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import "context" + +// Container interface brings together all repository instances. +type Container interface { + // User method returns repository.User concrete implementation. + User() User + + // Roster method returns repository.Roster concrete implementation. + Roster() Roster + + // Presences method returns repository.Presences concrete implementation. + Presences() Presences + + // VCard method returns repository.VCard concrete implementation. + VCard() VCard + + // Private method returns repository.Private concrete implementation. + Private() Private + + // BlockList method returns repository.BlockList concrete implementation. + BlockList() BlockList + + // PubSub method returns repository.PubSub concrete implementation. + PubSub() PubSub + + // Offline method returns repository.Offline concrete implementation. + Offline() Offline + + // Close closes underlying storage resources, commonly shared across repositories. + Close(ctx context.Context) error + + // IsClusterCompatible tells whether or not container instance can be safely used across multiple cluster nodes. + IsClusterCompatible() bool + + // Room method returns respository.Room concrete implementation + Room() Room + + // Occupant method returns repository.Occupant concrete implementation + Occupant() Occupant +} diff --git a/storage/repository/occupant.go b/storage/repository/occupant.go new file mode 100644 index 000000000..520761c22 --- /dev/null +++ b/storage/repository/occupant.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +// User defines user repository operations +type Occupant interface { + // UpsertOccupant inserts a new occupant entity into storage, or updates it if previously inserted. + UpsertOccupant(ctx context.Context, occ *mucmodel.Occupant) error + + // DeleteOccupant deletes a occupant entity from storage. + DeleteOccupant(ctx context.Context, occJID *jid.JID) error + + // FetchOccupant retrieves an occupant entity from storage. + FetchOccupant(ctx context.Context, occJID *jid.JID) (*mucmodel.Occupant, error) + + // OccupantExists tells whether or not an occupant exists within storage. + OccupantExists(ctx context.Context, occJID *jid.JID) (bool, error) +} diff --git a/storage/repository/offline.go b/storage/repository/offline.go new file mode 100644 index 000000000..ddf62e57f --- /dev/null +++ b/storage/repository/offline.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + "github.com/ortuman/jackal/xmpp" +) + +// Offline defines storage operations for offline messages +type Offline interface { + // InsertOfflineMessage inserts a new message element into user's offline queue. + InsertOfflineMessage(ctx context.Context, message *xmpp.Message, username string) error + + // CountOfflineMessages returns current length of user's offline queue. + CountOfflineMessages(ctx context.Context, username string) (int, error) + + // FetchOfflineMessages retrieves from storage current user offline queue. + FetchOfflineMessages(ctx context.Context, username string) ([]xmpp.Message, error) + + // DeleteOfflineMessages clears a user offline queue. + DeleteOfflineMessages(ctx context.Context, username string) error +} diff --git a/storage/repository/presences.go b/storage/repository/presences.go new file mode 100644 index 000000000..e282dfe49 --- /dev/null +++ b/storage/repository/presences.go @@ -0,0 +1,36 @@ +package repository + +import ( + "context" + + capsmodel "github.com/ortuman/jackal/model/capabilities" + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +type Presences interface { + // UpsertPresence inserts or updates a presence and links it to certain allocation. + // On insertion 'inserted' return parameter will be true. + UpsertPresence(ctx context.Context, presence *xmpp.Presence, jid *jid.JID, allocationID string) (inserted bool, err error) + + // FetchPresence retrieves from storage a previously registered presence. + FetchPresence(ctx context.Context, jid *jid.JID) (*capsmodel.PresenceCaps, error) + + // FetchPresencesMatchingJID retrives all storage presences matching a certain JID + FetchPresencesMatchingJID(ctx context.Context, jid *jid.JID) ([]capsmodel.PresenceCaps, error) + + // DeletePresence removes from storage a concrete registered presence. + DeletePresence(ctx context.Context, jid *jid.JID) error + + // DeleteAllocationPresences removes from storage all presences associated to a given allocation. + DeleteAllocationPresences(ctx context.Context, allocationID string) error + + // ClearPresences wipes out all storage presences. + ClearPresences(ctx context.Context) error + + // UpsertCapabilities inserts capabilities associated to a node+ver pair, or updates them if previously inserted.. + UpsertCapabilities(ctx context.Context, caps *capsmodel.Capabilities) error + + // FetchCapabilities fetches capabilities associated to a give node and ver. + FetchCapabilities(ctx context.Context, node, ver string) (*capsmodel.Capabilities, error) +} diff --git a/storage/repository/private.go b/storage/repository/private.go new file mode 100644 index 000000000..7d68ed751 --- /dev/null +++ b/storage/repository/private.go @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + "github.com/ortuman/jackal/xmpp" +) + +// Private defines operations for private storage. +type Private interface { + // FetchPrivateXML retrieves from storage a private element. + FetchPrivateXML(ctx context.Context, namespace string, username string) ([]xmpp.XElement, error) + + // UpsertPrivateXML inserts a new private element into storage, or updates it if previously inserted. + UpsertPrivateXML(ctx context.Context, privateXML []xmpp.XElement, namespace string, username string) error +} diff --git a/storage/repository/pubsub.go b/storage/repository/pubsub.go new file mode 100644 index 000000000..ddeaa856c --- /dev/null +++ b/storage/repository/pubsub.go @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + pubsubmodel "github.com/ortuman/jackal/model/pubsub" +) + +// PubSub defines storage operations for pubsub management. +type PubSub interface { + // FetchHosts returns all host identifiers. + FetchHosts(ctx context.Context) (hosts []string, err error) + + // UpsertNode inserts a new pubsub node entity into storage, or updates it if previously inserted. + UpsertNode(ctx context.Context, node *pubsubmodel.Node) error + + // FetchNode retrieves from storage a pubsub node entity. + FetchNode(ctx context.Context, host, name string) (*pubsubmodel.Node, error) + + // FetchNodes retrieves from storage all node entities associated with a host. + FetchNodes(ctx context.Context, host string) ([]pubsubmodel.Node, error) + + // FetchSubscribedNodes retrieves from storage all nodes to which a given jid is subscribed. + FetchSubscribedNodes(ctx context.Context, jid string) ([]pubsubmodel.Node, error) + + // DeleteNode deletes a pubsub node from storage. + DeleteNode(ctx context.Context, host, name string) error + + // UpsertNodeItem inserts a new pubsub node item entity into storage, or updates it if previously inserted. + UpsertNodeItem(ctx context.Context, item *pubsubmodel.Item, host, name string, maxNodeItems int) error + + // FetchNodeItems retrieves all items associated to a node. + FetchNodeItems(ctx context.Context, host, name string) ([]pubsubmodel.Item, error) + + // FetchNodeItemsWithIDs retrieves all items matching any of the passed identifiers. + FetchNodeItemsWithIDs(ctx context.Context, host, name string, identifiers []string) ([]pubsubmodel.Item, error) + + // FetchNodeLastItem retrieves last published node item. + FetchNodeLastItem(ctx context.Context, host, name string) (*pubsubmodel.Item, error) + + // UpsertNodeAffiliation inserts a new pubsub node affiliation into storage, or updates it if previously inserted. + UpsertNodeAffiliation(ctx context.Context, affiliation *pubsubmodel.Affiliation, host, name string) error + + // FetchNodeAffiliation retrieves a concrete node affiliation from storage. + FetchNodeAffiliation(ctx context.Context, host, name, jid string) (*pubsubmodel.Affiliation, error) + + // FetchNodeAffiliations retrieves all affiliations associated to a node. + FetchNodeAffiliations(ctx context.Context, host, name string) ([]pubsubmodel.Affiliation, error) + + // DeleteNodeAffiliation deletes a pubsub node affiliation from storage. + DeleteNodeAffiliation(ctx context.Context, jid, host, name string) error + + // UpsertNodeSubscription inserts a new pubsub node subscription into storage, or updates it if previously inserted. + UpsertNodeSubscription(ctx context.Context, subscription *pubsubmodel.Subscription, host, name string) error + + // FetchNodeSubscriptions retrieves all subscriptions associated to a node. + FetchNodeSubscriptions(ctx context.Context, host, name string) ([]pubsubmodel.Subscription, error) + + // DeleteNodeSubscription deletes a pubsub node subscription from storage. + DeleteNodeSubscription(ctx context.Context, jid, host, name string) error +} diff --git a/storage/repository/room.go b/storage/repository/room.go new file mode 100644 index 000000000..399078ab9 --- /dev/null +++ b/storage/repository/room.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + mucmodel "github.com/ortuman/jackal/model/muc" + "github.com/ortuman/jackal/xmpp/jid" +) + +// Room defines room repository operations +type Room interface { + // UpsertRoom inserts a new room entity into storage, or updates it if previously inserted. + UpsertRoom(ctx context.Context, room *mucmodel.Room) error + + // DeleteRoom deletes a room entity from storage. + DeleteRoom(ctx context.Context, roomJID *jid.JID) error + + // FetchRoom retrieves a room entity from storage. + FetchRoom(ctx context.Context, roomJID *jid.JID) (*mucmodel.Room, error) + + // RoomExists tells whether or not a room exists within storage. + RoomExists(ctx context.Context, roomJID *jid.JID) (bool, error) +} diff --git a/storage/repository/roster.go b/storage/repository/roster.go new file mode 100644 index 000000000..2510d054f --- /dev/null +++ b/storage/repository/roster.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + rostermodel "github.com/ortuman/jackal/model/roster" +) + +// Roster defines storage operations for user's roster. +type Roster interface { + // UpsertRosterItem inserts a new roster item entity into storage, or updates it if previously inserted. + UpsertRosterItem(ctx context.Context, ri *rostermodel.Item) (rostermodel.Version, error) + + // DeleteRosterItem deletes a roster item entity from storage. + DeleteRosterItem(ctx context.Context, username, jid string) (rostermodel.Version, error) + + // FetchRosterItems retrieves from storage all roster item entities associated to a given user. + FetchRosterItems(ctx context.Context, username string) ([]rostermodel.Item, rostermodel.Version, error) + + // FetchRosterItemsInGroups retrieves from storage all roster item entities associated to a given user and a set of groups. + FetchRosterItemsInGroups(ctx context.Context, username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) + + // FetchRosterItem retrieves from storage a roster item entity. + FetchRosterItem(ctx context.Context, username, jid string) (*rostermodel.Item, error) + + // UpsertRosterNotification inserts a new roster notification entity into storage, or updates it if previously inserted. + UpsertRosterNotification(ctx context.Context, rn *rostermodel.Notification) error + + // DeleteRosterNotification deletes a roster notification entity from storage. + DeleteRosterNotification(ctx context.Context, contact, jid string) error + + // FetchRosterNotification retrieves from storage a roster notification entity. + FetchRosterNotification(ctx context.Context, contact string, jid string) (*rostermodel.Notification, error) + + // FetchRosterNotifications retrieves from storage all roster notifications associated to a given user. + FetchRosterNotifications(ctx context.Context, contact string) ([]rostermodel.Notification, error) + + // FetchRosterGroups retrieves all groups associated to a user roster. + FetchRosterGroups(ctx context.Context, username string) ([]string, error) +} diff --git a/storage/repository/user.go b/storage/repository/user.go new file mode 100644 index 000000000..a38284c28 --- /dev/null +++ b/storage/repository/user.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + "github.com/ortuman/jackal/model" +) + +// User defines user repository operations +type User interface { + // UpsertUser inserts a new user entity into storage, or updates it if previously inserted. + UpsertUser(ctx context.Context, user *model.User) error + + // DeleteUser deletes a user entity from storage. + DeleteUser(ctx context.Context, username string) error + + // FetchUser retrieves a user entity from storage. + FetchUser(ctx context.Context, username string) (*model.User, error) + + // UserExists tells whether or not a user exists within storage. + UserExists(ctx context.Context, username string) (bool, error) +} diff --git a/storage/repository/vcard.go b/storage/repository/vcard.go new file mode 100644 index 000000000..38d350f2e --- /dev/null +++ b/storage/repository/vcard.go @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package repository + +import ( + "context" + + "github.com/ortuman/jackal/xmpp" +) + +// VCard defines storage operations for vCards +type VCard interface { + // UpsertVCard inserts a new vCard element into storage, or updates it in case it's been previously inserted. + UpsertVCard(ctx context.Context, vCard xmpp.XElement, username string) error + + // FetchVCard retrieves from storage a vCard element associated to a given user. + FetchVCard(ctx context.Context, username string) (xmpp.XElement, error) +} diff --git a/storage/roster.go b/storage/roster.go deleted file mode 100644 index 9a0468da5..000000000 --- a/storage/roster.go +++ /dev/null @@ -1,66 +0,0 @@ -package storage - -import "github.com/ortuman/jackal/model/rostermodel" - -// rosterStorage defines storage oprations for user's roster -type rosterStorage interface { - InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) - DeleteRosterItem(username, jid string) (rostermodel.Version, error) - FetchRosterItems(username string) ([]rostermodel.Item, rostermodel.Version, error) - FetchRosterItemsInGroups(username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) - FetchRosterItem(username, jid string) (*rostermodel.Item, error) - InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error - DeleteRosterNotification(contact, jid string) error - FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) - FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) -} - -// InsertOrUpdateRosterItem inserts a new roster item entity into storage, -// or updates it in case it's been previously inserted. -func InsertOrUpdateRosterItem(ri *rostermodel.Item) (rostermodel.Version, error) { - return instance().InsertOrUpdateRosterItem(ri) -} - -// DeleteRosterItem deletes a roster item entity from storage. -func DeleteRosterItem(username, jid string) (rostermodel.Version, error) { - return instance().DeleteRosterItem(username, jid) -} - -// FetchRosterItems retrieves from storage all roster item entities -// associated to a given user. -func FetchRosterItems(username string) ([]rostermodel.Item, rostermodel.Version, error) { - return instance().FetchRosterItems(username) -} - -// FetchRosterItemsInGroups retrieves from storage all roster item entities -// associated to a given user and a set of groups. -func FetchRosterItemsInGroups(username string, groups []string) ([]rostermodel.Item, rostermodel.Version, error) { - return instance().FetchRosterItemsInGroups(username, groups) -} - -// FetchRosterItem retrieves from storage a roster item entity. -func FetchRosterItem(username, jid string) (*rostermodel.Item, error) { - return instance().FetchRosterItem(username, jid) -} - -// InsertOrUpdateRosterNotification inserts a new roster notification entity -// into storage, or updates it in case it's been previously inserted. -func InsertOrUpdateRosterNotification(rn *rostermodel.Notification) error { - return instance().InsertOrUpdateRosterNotification(rn) -} - -// DeleteRosterNotification deletes a roster notification entity from storage. -func DeleteRosterNotification(contact, jid string) error { - return instance().DeleteRosterNotification(contact, jid) -} - -// FetchRosterNotification retrieves from storage a roster notification entity. -func FetchRosterNotification(contact string, jid string) (*rostermodel.Notification, error) { - return instance().FetchRosterNotification(contact, jid) -} - -// FetchRosterNotifications retrieves from storage all roster notifications -// associated to a given user. -func FetchRosterNotifications(contact string) ([]rostermodel.Notification, error) { - return instance().FetchRosterNotifications(contact) -} diff --git a/storage/storage.go b/storage/storage.go index cd482b9ab..800253213 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. * See the LICENSE file for more information. */ @@ -7,78 +7,23 @@ package storage import ( "fmt" - "sync" - "github.com/ortuman/jackal/storage/badgerdb" - "github.com/ortuman/jackal/storage/memstorage" + memorystorage "github.com/ortuman/jackal/storage/memory" "github.com/ortuman/jackal/storage/mysql" "github.com/ortuman/jackal/storage/pgsql" + "github.com/ortuman/jackal/storage/repository" ) -// Storage represents an entity storage interface. -type Storage interface { - Close() error - - IsClusterCompatible() bool - - userStorage - offlineStorage - rosterStorage - vCardStorage - privateStorage - blockListStorage -} - -var ( - instMu sync.RWMutex - inst Storage -) - -// Disabled stores a disabled storage instance. -var Disabled Storage = &disabledStorage{} - -func init() { - inst = Disabled -} - -// New initializes storage sub system. -func New(config *Config) (Storage, error) { +// New initializes configured storage type and returns associated container. +func New(config *Config) (repository.Container, error) { switch config.Type { - case BadgerDB: - return badgerdb.New(config.BadgerDB), nil case MySQL: - return mysql.New(config.MySQL), nil + return mysql.New(config.MySQL) case PostgreSQL: - return pgsql.New(config.PostgreSQL), nil + return pgsql.New(config.PostgreSQL) case Memory: - return memstorage.New(), nil + return memorystorage.New() default: return nil, fmt.Errorf("storage: unrecognized storage type: %d", config.Type) } } - -// Set sets the global storage. -func Set(storage Storage) { - instMu.Lock() - _ = inst.Close() - inst = storage - instMu.Unlock() -} - -// Unset disables a previously set global storage. -func Unset() { - Set(Disabled) -} - -// IsClusterCompatible returns whether or not the underlying storage subsystem can be used in cluster mode. -func IsClusterCompatible() bool { - return instance().IsClusterCompatible() -} - -// instance returns a singleton instance of the storage subsystem -func instance() Storage { - instMu.RLock() - s := inst - instMu.RUnlock() - return s -} diff --git a/storage/user.go b/storage/user.go deleted file mode 100644 index 08d180393..000000000 --- a/storage/user.go +++ /dev/null @@ -1,32 +0,0 @@ -package storage - -import "github.com/ortuman/jackal/model" - -// userStorage defines storage operations for users -type userStorage interface { - InsertOrUpdateUser(user *model.User) error - DeleteUser(username string) error - FetchUser(username string) (*model.User, error) - UserExists(username string) (bool, error) -} - -// InsertOrUpdateUser inserts a new user entity into storage, -// or updates it in case it's been previously inserted. -func InsertOrUpdateUser(user *model.User) error { - return instance().InsertOrUpdateUser(user) -} - -// DeleteUser deletes a user entity from storage. -func DeleteUser(username string) error { - return instance().DeleteUser(username) -} - -// FetchUser retrieves from storage a user entity. -func FetchUser(username string) (*model.User, error) { - return instance().FetchUser(username) -} - -// UserExists returns whether or not a user exists within storage. -func UserExists(username string) (bool, error) { - return instance().UserExists(username) -} diff --git a/storage/vcard.go b/storage/vcard.go deleted file mode 100644 index d2e71f0c2..000000000 --- a/storage/vcard.go +++ /dev/null @@ -1,21 +0,0 @@ -package storage - -import "github.com/ortuman/jackal/xmpp" - -// vCardStorage defines storage operations for vCards -type vCardStorage interface { - InsertOrUpdateVCard(vCard xmpp.XElement, username string) error - FetchVCard(username string) (xmpp.XElement, error) -} - -// InsertOrUpdateVCard inserts a new vCard element into storage, -// or updates it in case it's been previously inserted. -func InsertOrUpdateVCard(vCard xmpp.XElement, username string) error { - return instance().InsertOrUpdateVCard(vCard, username) -} - -// FetchVCard retrieves from storage a vCard element associated -// to a given user. -func FetchVCard(username string) (xmpp.XElement, error) { - return instance().FetchVCard(username) -} diff --git a/stream/mocked.go b/stream/mocked.go new file mode 100644 index 000000000..bda686184 --- /dev/null +++ b/stream/mocked.go @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + +package stream + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/ortuman/jackal/xmpp" + "github.com/ortuman/jackal/xmpp/jid" +) + +// MockC2S represents a mocked c2s stream. +type MockC2S struct { + id string + mu sync.RWMutex + isSecured bool + isAuthenticated bool + isCompressed bool + isDisconnected bool + jid *jid.JID + presence *xmpp.Presence + elemCh chan xmpp.XElement + actorCh chan func() + discCh chan error + ctx context.Context +} + +// NewMockC2S returns a new mocked stream instance. +func NewMockC2S(id string, jid *jid.JID) *MockC2S { + stm := &MockC2S{ + id: id, + ctx: context.Background(), + elemCh: make(chan xmpp.XElement, 16), + actorCh: make(chan func(), 64), + discCh: make(chan error, 1), + } + stm.SetJID(jid) + go stm.actorLoop() + return stm +} + +// ID returns mocked stream identifier. +func (m *MockC2S) ID() string { + return m.id +} + +func (m *MockC2S) Context() context.Context { + m.mu.RLock() + defer m.mu.RUnlock() + return m.ctx +} + +func (m *MockC2S) Value(key interface{}) interface{} { + m.mu.RLock() + defer m.mu.RUnlock() + return m.ctx.Value(key) +} + +func (m *MockC2S) SetValue(key, value interface{}) { + m.mu.Lock() + defer m.mu.Unlock() + m.ctx = context.WithValue(m.ctx, key, value) +} + +// Username returns current mocked stream username. +func (m *MockC2S) Username() string { + return m.JID().Node() +} + +// Domain returns current mocked stream domain. +func (m *MockC2S) Domain() string { + return m.JID().Domain() +} + +// Resource returns current mocked stream resource. +func (m *MockC2S) Resource() string { + return m.JID().Resource() +} + +// SetJID sets the mocked stream JID value. +func (m *MockC2S) SetJID(jid *jid.JID) { + m.mu.Lock() + defer m.mu.Unlock() + m.jid = jid +} + +// JID returns current user JID. +func (m *MockC2S) JID() *jid.JID { + m.mu.RLock() + defer m.mu.RUnlock() + return m.jid +} + +// SetSecured sets whether or not the a mocked stream +// has been secured. +func (m *MockC2S) SetSecured(secured bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.isSecured = secured +} + +// IsSecured returns whether or not the mocked stream +// has been secured. +func (m *MockC2S) IsSecured() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.isSecured +} + +// SetAuthenticated sets whether or not the a mocked stream +// has been authenticated. +func (m *MockC2S) SetAuthenticated(authenticated bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.isAuthenticated = authenticated +} + +// IsAuthenticated returns whether or not the mocked stream +// has successfully authenticated. +func (m *MockC2S) IsAuthenticated() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.isAuthenticated +} + +// IsDisconnected returns whether or not the mocked stream has been disconnected. +func (m *MockC2S) IsDisconnected() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.isDisconnected +} + +// SetPresence sets the mocked stream last received +// presence element. +func (m *MockC2S) SetPresence(presence *xmpp.Presence) { + m.mu.Lock() + defer m.mu.Unlock() + m.presence = presence +} + +// Presence returns last sent presence element. +func (m *MockC2S) Presence() *xmpp.Presence { + m.mu.RLock() + defer m.mu.RUnlock() + return m.presence +} + +// SendElement sends the given XML element. +func (m *MockC2S) SendElement(_ context.Context, elem xmpp.XElement) { + m.actorCh <- func() { + m.sendElement(elem) + } +} + +// Disconnect disconnects mocked stream. +func (m *MockC2S) Disconnect(_ context.Context, err error) { + waitCh := make(chan struct{}) + m.actorCh <- func() { + m.disconnect(err) + close(waitCh) + } + <-waitCh +} + +// ReceiveElement waits until a new XML element is sent to +// the mocked stream and returns it. +func (m *MockC2S) ReceiveElement() xmpp.XElement { + select { + case e := <-m.elemCh: + return e + case <-time.After(time.Second * 5): + return &xmpp.Element{} + } +} + +// WaitDisconnection waits until the mocked stream disconnects. +func (m *MockC2S) WaitDisconnection() error { + select { + case err := <-m.discCh: + return err + case <-time.After(time.Second * 5): + return errors.New("operation timed out") + } +} + +func (m *MockC2S) actorLoop() { + for { + select { + case f := <-m.actorCh: + f() + case <-m.discCh: + return + } + } +} + +func (m *MockC2S) sendElement(elem xmpp.XElement) { + select { + case m.elemCh <- elem: + return + default: + break + } +} + +func (m *MockC2S) disconnect(err error) { + m.mu.Lock() + defer m.mu.Unlock() + if !m.isDisconnected { + m.discCh <- err + m.isDisconnected = true + } +} diff --git a/stream/stream.go b/stream/stream.go index b0ae96e38..e2aa316b5 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -6,9 +6,7 @@ package stream import ( - "errors" - "sync" - "time" + "context" "github.com/ortuman/jackal/xmpp" "github.com/ortuman/jackal/xmpp/jid" @@ -17,32 +15,23 @@ import ( // InStream represents a generic incoming stream. type InStream interface { ID() string - Disconnect(err error) + Disconnect(ctx context.Context, err error) } // InOutStream represents a generic incoming/outgoing stream. type InOutStream interface { InStream - SendElement(elem xmpp.XElement) + SendElement(ctx context.Context, elem xmpp.XElement) } // C2S represents a client-to-server XMPP stream. type C2S interface { InOutStream - Context() map[string]interface{} + Context() context.Context - SetString(key string, value string) - GetString(key string) string - - SetInt(key string, value int) - GetInt(key string) int - - SetFloat(key string, value float64) - GetFloat(key string) float64 - - SetBool(key string, value bool) - GetBool(key string) bool + SetValue(key, value interface{}) + Value(key interface{}) interface{} Username() string Domain() string @@ -65,270 +54,3 @@ type S2SIn interface { type S2SOut interface { InOutStream } - -// MockC2S represents a mocked c2s stream. -type MockC2S struct { - id string - mu sync.RWMutex - isSecured bool - isAuthenticated bool - isCompressed bool - isDisconnected bool - jid *jid.JID - presence *xmpp.Presence - contextMu sync.RWMutex - context map[string]interface{} - elemCh chan xmpp.XElement - actorCh chan func() - discCh chan error -} - -// NewMockC2S returns a new mocked stream instance. -func NewMockC2S(id string, jid *jid.JID) *MockC2S { - stm := &MockC2S{ - id: id, - context: make(map[string]interface{}), - elemCh: make(chan xmpp.XElement, 16), - actorCh: make(chan func(), 64), - discCh: make(chan error, 1), - } - stm.SetJID(jid) - go stm.actorLoop() - return stm -} - -// ID returns mocked stream identifier. -func (m *MockC2S) ID() string { - return m.id -} - -// Context returns a copy of the stream associated context. -func (m *MockC2S) Context() map[string]interface{} { - ret := make(map[string]interface{}) - m.contextMu.RLock() - for k, v := range m.context { - ret[k] = v - } - m.contextMu.RUnlock() - return ret -} - -// SetString associates a string context value to a key. -func (m *MockC2S) SetString(key string, value string) { - m.setContextValue(key, value) -} - -// GetString returns the context value associated with the key as a string. -func (m *MockC2S) GetString(key string) string { - var ret string - m.contextMu.RLock() - defer m.contextMu.RUnlock() - if s, ok := m.context[key].(string); ok { - ret = s - } - return ret -} - -// SetInt associates an integer context value to a key. -func (m *MockC2S) SetInt(key string, value int) { - m.setContextValue(key, value) -} - -// GetInt returns the context value associated with the key as an integer. -func (m *MockC2S) GetInt(key string) int { - var ret int - m.contextMu.RLock() - defer m.contextMu.RUnlock() - if i, ok := m.context[key].(int); ok { - ret = i - } - return ret -} - -// SetFloat associates a float context value to a key. -func (m *MockC2S) SetFloat(key string, value float64) { - m.setContextValue(key, value) -} - -// GetFloat returns the context value associated with the key as a float64. -func (m *MockC2S) GetFloat(key string) float64 { - var ret float64 - m.contextMu.RLock() - defer m.contextMu.RUnlock() - if f, ok := m.context[key].(float64); ok { - ret = f - } - return ret -} - -// SetBool associates a boolean context value to a key. -func (m *MockC2S) SetBool(key string, value bool) { - m.setContextValue(key, value) -} - -// GetBool returns the context value associated with the key as a boolean. -func (m *MockC2S) GetBool(key string) bool { - var ret bool - m.contextMu.RLock() - defer m.contextMu.RUnlock() - if b, ok := m.context[key].(bool); ok { - ret = b - } - return ret -} - -// Username returns current mocked stream username. -func (m *MockC2S) Username() string { - return m.JID().Node() -} - -// Domain returns current mocked stream domain. -func (m *MockC2S) Domain() string { - return m.JID().Domain() -} - -// Resource returns current mocked stream resource. -func (m *MockC2S) Resource() string { - return m.JID().Resource() -} - -// SetJID sets the mocked stream JID value. -func (m *MockC2S) SetJID(jid *jid.JID) { - m.mu.Lock() - defer m.mu.Unlock() - m.jid = jid -} - -// JID returns current user JID. -func (m *MockC2S) JID() *jid.JID { - m.mu.RLock() - defer m.mu.RUnlock() - return m.jid -} - -// SetSecured sets whether or not the a mocked stream -// has been secured. -func (m *MockC2S) SetSecured(secured bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.isSecured = secured -} - -// IsSecured returns whether or not the mocked stream -// has been secured. -func (m *MockC2S) IsSecured() bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.isSecured -} - -// SetAuthenticated sets whether or not the a mocked stream -// has been authenticated. -func (m *MockC2S) SetAuthenticated(authenticated bool) { - m.mu.Lock() - defer m.mu.Unlock() - m.isAuthenticated = authenticated -} - -// IsAuthenticated returns whether or not the mocked stream -// has successfully authenticated. -func (m *MockC2S) IsAuthenticated() bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.isAuthenticated -} - -// IsDisconnected returns whether or not the mocked stream has been disconnected. -func (m *MockC2S) IsDisconnected() bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.isDisconnected -} - -// SetPresence sets the mocked stream last received -// presence element. -func (m *MockC2S) SetPresence(presence *xmpp.Presence) { - m.mu.Lock() - defer m.mu.Unlock() - m.presence = presence -} - -// Presence returns last sent presence element. -func (m *MockC2S) Presence() *xmpp.Presence { - m.mu.RLock() - defer m.mu.RUnlock() - return m.presence -} - -// SendElement sends the given XML element. -func (m *MockC2S) SendElement(elem xmpp.XElement) { - m.actorCh <- func() { - m.sendElement(elem) - } -} - -// Disconnect disconnects mocked stream. -func (m *MockC2S) Disconnect(err error) { - waitCh := make(chan struct{}) - m.actorCh <- func() { - m.disconnect(err) - close(waitCh) - } - <-waitCh -} - -// ReceiveElement waits until a new XML element is sent to -// the mocked stream and returns it. -func (m *MockC2S) ReceiveElement() xmpp.XElement { - select { - case e := <-m.elemCh: - return e - case <-time.After(time.Second * 5): - return &xmpp.Element{} - } -} - -// WaitDisconnection waits until the mocked stream disconnects. -func (m *MockC2S) WaitDisconnection() error { - select { - case err := <-m.discCh: - return err - case <-time.After(time.Second * 5): - return errors.New("operation timed out") - } -} - -func (m *MockC2S) actorLoop() { - for { - select { - case f := <-m.actorCh: - f() - case <-m.discCh: - return - } - } -} - -func (m *MockC2S) sendElement(elem xmpp.XElement) { - select { - case m.elemCh <- elem: - return - default: - break - } -} - -func (m *MockC2S) disconnect(err error) { - m.mu.Lock() - defer m.mu.Unlock() - if !m.isDisconnected { - m.discCh <- err - m.isDisconnected = true - } -} - -func (m *MockC2S) setContextValue(key string, value interface{}) { - m.contextMu.Lock() - defer m.contextMu.Unlock() - m.context[key] = value -} diff --git a/stream/stream_test.go b/stream/stream_test.go index be00cd2ba..25b99d5ea 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -6,6 +6,7 @@ package stream import ( + "context" "testing" "github.com/ortuman/jackal/xmpp" @@ -35,12 +36,12 @@ func TestMockC2Stream(t *testing.T) { require.Equal(t, 1, len(presenceElements)) elem := xmpp.NewElementName("elem1234") - stm.SendElement(elem) + stm.SendElement(context.Background(), elem) fetch := stm.ReceiveElement() require.NotNil(t, fetch) require.Equal(t, "elem1234", fetch.Name()) - stm.Disconnect(nil) + stm.Disconnect(context.Background(), nil) require.True(t, stm.IsDisconnected()) stm.SetSecured(true) require.True(t, stm.IsSecured()) diff --git a/testdata/config_modules.yml b/testdata/config_modules.yml index 1708b1dba..5e0cde01a 100644 --- a/testdata/config_modules.yml +++ b/testdata/config_modules.yml @@ -10,6 +10,7 @@ enabled: - blocking_command - ping - offline + - muc mod_roster: versioning: true @@ -28,3 +29,21 @@ mod_version: mod_ping: send: no send_interval: 60 + +mod_muc: + host: conference.jackal.im + name: "Chatroom Server" + room_defaults: + public: true + persistent: true + password_protected: false + open: true + moderated: false + allow_invites: false + allow_subject_change: true + enable_logging: true + non_anonymous: true + occupant_count: -1 # -1 means don't set the limit + # options for the next ones are "all", "moderators" and "" + can_get_member_list: "all" + send_pm: "all" diff --git a/transport/quicsocket.go b/transport/quicsocket.go index 58019da1e..847779fe2 100644 --- a/transport/quicsocket.go +++ b/transport/quicsocket.go @@ -8,7 +8,6 @@ package transport import ( "bufio" "crypto/tls" - "time" "github.com/lucas-clemente/quic-go" ) @@ -19,14 +18,12 @@ type quicSocketTransport struct { } // NewQUICSocketTransport create and return a new quicSocketTransport. -func NewQUICSocketTransport(conn quic.Session, uniStream quic.Stream, - keepAlive time.Duration) Transport { +func NewQUICSocketTransport(conn quic.Session, uniStream quic.Stream) Transport { s := &quicSocketTransport{ socketTransport: socketTransport{ rw: uniStream, br: bufio.NewReaderSize(uniStream, socketBuffSize), bw: bufio.NewWriterSize(uniStream, socketBuffSize), - keepAlive: keepAlive, }, conn: conn, } diff --git a/transport/socket.go b/transport/socket.go index f39eb43a0..ba2c6e6c3 100644 --- a/transport/socket.go +++ b/transport/socket.go @@ -24,26 +24,21 @@ type socketTransport struct { rw io.ReadWriter br *bufio.Reader bw *bufio.Writer - keepAlive time.Duration compressed bool } // NewSocketTransport creates a socket class stream transport. -func NewSocketTransport(conn net.Conn, keepAlive time.Duration) Transport { +func NewSocketTransport(conn net.Conn) Transport { s := &socketTransport{ - conn: conn, - rw: conn, - br: bufio.NewReaderSize(conn, socketBuffSize), - bw: bufio.NewWriterSize(conn, socketBuffSize), - keepAlive: keepAlive, + conn: conn, + rw: conn, + br: bufio.NewReaderSize(conn, socketBuffSize), + bw: bufio.NewWriterSize(conn, socketBuffSize), } return s } func (s *socketTransport) Read(p []byte) (n int, err error) { - if s.keepAlive > 0 { - s.conn.SetReadDeadline(time.Now().Add(s.keepAlive)) - } return s.br.Read(p) } @@ -69,6 +64,11 @@ func (s *socketTransport) Flush() error { return s.bw.Flush() } +// SetWriteDeadline sets the deadline for future write calls. +func (s *socketTransport) SetWriteDeadline(d time.Time) error { + return s.conn.SetWriteDeadline(d) +} + func (s *socketTransport) StartTLS(cfg *tls.Config, asClient bool) { if _, ok := s.conn.(*net.TCPConn); ok { if asClient { diff --git a/transport/socket_test.go b/transport/socket_test.go index 461484854..62c073908 100644 --- a/transport/socket_test.go +++ b/transport/socket_test.go @@ -52,7 +52,7 @@ func (a fakeAddr) String() string { return "str" } func TestSocket(t *testing.T) { buff := make([]byte, 4096) conn := newFakeSocketConn() - st := NewSocketTransport(conn, 4096) + st := NewSocketTransport(conn) st2 := st.(*socketTransport) el1 := xmpp.NewElementNamespace("elem", "exodus:ns") diff --git a/transport/transport.go b/transport/transport.go index c4d002086..1acd9fd26 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "crypto/x509" "io" + "time" "github.com/ortuman/jackal/transport/compress" ) @@ -19,9 +20,6 @@ type Type int const ( // Socket represents a socket transport type. Socket Type = iota + 1 - - // WebSocket represents a websocket transport type. - WebSocket ) // String returns TransportType string representation. @@ -29,8 +27,6 @@ func (tt Type) String() string { switch tt { case Socket: return "socket" - case WebSocket: - return "websocket" } return "" } @@ -56,19 +52,19 @@ type Transport interface { // Flush writes any buffered data to the underlying io.Writer. Flush() error + // SetWriteDeadline sets the deadline for future write calls. + SetWriteDeadline(d time.Time) error + // StartTLS secures the transport using SSL/TLS StartTLS(cfg *tls.Config, asClient bool) - // EnableCompression activates a compression - // mechanism on the transport. + // EnableCompression activates a compression mechanism on the transport. EnableCompression(compress.Level) - // ChannelBindingBytes returns current transport - // channel binding bytes. + // ChannelBindingBytes returns current transport channel binding bytes. ChannelBindingBytes(ChannelBindingMechanism) []byte - // PeerCertificates returns the certificate chain - // presented by remote peer. + // PeerCertificates returns the certificate chain presented by remote peer. PeerCertificates() []*x509.Certificate } diff --git a/transport/websocket.go b/transport/websocket.go deleted file mode 100644 index ca5ef91b5..000000000 --- a/transport/websocket.go +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package transport - -import ( - "bytes" - "crypto/tls" - "crypto/x509" - "io" - "net" - "strings" - "time" - - "github.com/gorilla/websocket" - "github.com/ortuman/jackal/transport/compress" -) - -// WebSocketConn represents a websocket connection interface. -type WebSocketConn interface { - NextReader() (messageType int, r io.Reader, err error) - NextWriter(int) (io.WriteCloser, error) - Close() error - UnderlyingConn() net.Conn - SetReadDeadline(t time.Time) error -} - -type webSocketTransport struct { - conn WebSocketConn - r *bytes.Reader - keepAlive time.Duration -} - -// NewWebSocketTransport creates a socket class stream transport. -func NewWebSocketTransport(conn WebSocketConn, keepAlive time.Duration) Transport { - wst := &webSocketTransport{ - conn: conn, - keepAlive: keepAlive, - } - return wst -} - -func (wst *webSocketTransport) Read(p []byte) (n int, err error) { - _, r, err := wst.conn.NextReader() - if err != nil { - return 0, err - } - if wst.keepAlive > 0 { - wst.conn.SetReadDeadline(time.Now().Add(wst.keepAlive)) - } - return r.Read(p) -} - -func (wst *webSocketTransport) Write(p []byte) (n int, err error) { - w, err := wst.conn.NextWriter(websocket.TextMessage) - if err != nil { - return 0, err - } - defer w.Close() - return w.Write(p) -} - -func (wst *webSocketTransport) Close() error { - return wst.conn.Close() -} - -func (wst *webSocketTransport) Type() Type { - return WebSocket -} - -func (wst *webSocketTransport) WriteString(str string) (int, error) { - w, err := wst.conn.NextWriter(websocket.TextMessage) - if err != nil { - return 0, err - } - defer w.Close() - n, err := io.Copy(w, strings.NewReader(str)) - return int(n), err -} - -// Flush writes any buffered data to the underlying io.Writer. -func (wst *webSocketTransport) Flush() error { - return nil -} - -func (wst *webSocketTransport) StartTLS(_ *tls.Config, _ bool) { -} - -func (wst *webSocketTransport) EnableCompression(level compress.Level) { -} - -func (wst *webSocketTransport) ChannelBindingBytes(mechanism ChannelBindingMechanism) []byte { - if tlsConn, ok := wst.conn.UnderlyingConn().(tlsStateQueryable); ok { - switch mechanism { - case TLSUnique: - st := tlsConn.ConnectionState() - return st.TLSUnique - default: - break - } - } - return nil -} - -func (wst *webSocketTransport) PeerCertificates() []*x509.Certificate { - if tlsConn, ok := wst.conn.UnderlyingConn().(tlsStateQueryable); ok { - st := tlsConn.ConnectionState() - return st.PeerCertificates - } - return nil -} diff --git a/transport/websocket_test.go b/transport/websocket_test.go deleted file mode 100644 index 5f22363a2..000000000 --- a/transport/websocket_test.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package transport - -import ( - "bytes" - "crypto/tls" - "io" - "net" - "testing" - "time" - - "github.com/ortuman/jackal/xmpp" - "github.com/pborman/uuid" - "github.com/stretchr/testify/require" -) - -type fakeWebSocketReader struct { - buf *bytes.Buffer -} - -func (r *fakeWebSocketReader) Read(p []byte) (n int, err error) { return r.buf.Read(p) } - -type fakeWebSocketWriter struct { - buf *bytes.Buffer -} - -func (w *fakeWebSocketWriter) Write(p []byte) (n int, err error) { return w.buf.Write(p) } -func (w *fakeWebSocketWriter) Close() error { return nil } - -type fakeWebSocketConn struct { - r *fakeWebSocketReader - w *fakeWebSocketWriter - closed bool -} - -func newFakeWebSocketConn() *fakeWebSocketConn { - return &fakeWebSocketConn{ - r: &fakeWebSocketReader{buf: new(bytes.Buffer)}, - w: &fakeWebSocketWriter{buf: new(bytes.Buffer)}, - } -} - -func (c *fakeWebSocketConn) NextReader() (messageType int, r io.Reader, err error) { return 0, c.r, nil } -func (c *fakeWebSocketConn) NextWriter(int) (writer io.WriteCloser, err error) { return c.w, nil } -func (c *fakeWebSocketConn) Close() error { c.closed = true; return nil } -func (c *fakeWebSocketConn) SetReadDeadline(t time.Time) error { return nil } -func (c *fakeWebSocketConn) UnderlyingConn() net.Conn { return &tls.Conn{} } - -func TestWebSocketTransport(t *testing.T) { - buff := make([]byte, 4096) - conn := newFakeWebSocketConn() - - // test read... - iq := xmpp.NewIQType(uuid.New(), xmpp.ResultType) - iq.SetFrom("localhost") - iq.ToXML(conn.r.buf, true) - - wst := NewWebSocketTransport(conn, 120) - n, err := wst.Read(buff) - require.Nil(t, err) - require.Equal(t, iq.String(), string(buff[:n])) - - // test write... - msg := xmpp.NewMessageType(uuid.New(), xmpp.ChatType) - b := xmpp.NewElementName("body") - b.SetText("Hi buddy!") - msg.AppendElement(b) - - io.WriteString(wst, msg.String()) - require.Equal(t, msg.String(), conn.w.buf.String()) - conn.w.buf.Reset() - - msg.ToXML(wst, true) - require.Equal(t, msg.String(), conn.w.buf.String()) - - require.Nil(t, wst.ChannelBindingBytes(ChannelBindingMechanism(99))) - require.Nil(t, wst.ChannelBindingBytes(TLSUnique)) - - wst.Close() - require.True(t, conn.closed) -} diff --git a/pool/buffer.go b/util/pool/buffer.go similarity index 100% rename from pool/buffer.go rename to util/pool/buffer.go diff --git a/pool/buffer_test.go b/util/pool/buffer_test.go similarity index 72% rename from pool/buffer_test.go rename to util/pool/buffer_test.go index e080fb20e..e6a0e2cd9 100644 --- a/pool/buffer_test.go +++ b/util/pool/buffer_test.go @@ -6,10 +6,10 @@ package pool import ( + "math/rand" "reflect" "testing" - "github.com/ortuman/jackal/util" "github.com/stretchr/testify/require" ) @@ -22,7 +22,13 @@ func TestBufferPool_GetAndPut(t *testing.T) { require.Equal(t, "*bytes.Buffer", reflect.ValueOf(buf).Type().String()) buf = p.Get() - buf.Write(util.RandomBytes(randomBytesLength)) + + randomBytes := make([]byte, randomBytesLength) + _, err := rand.Read(randomBytes) + if err != nil { + t.Errorf("error reading random bytes: %v", err) + } + buf.Write(randomBytes) require.Equal(t, randomBytesLength, buf.Len()) p.Put(buf) buf = p.Get() diff --git a/util/rand.go b/util/rand.go deleted file mode 100644 index d92347343..000000000 --- a/util/rand.go +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package util - -import ( - "math/rand" - "time" -) - -func init() { - rand.Seed(time.Now().UTC().UnixNano()) -} - -// RandomBytes generates a random bytes slice of length 'len'. -func RandomBytes(len int) []byte { - b := make([]byte, len) - for i := 0; i < len; i++ { - b[i] = byte(rand.Intn(256)) - } - return b -} diff --git a/util/rand_test.go b/util/rand_test.go deleted file mode 100644 index 38b25007b..000000000 --- a/util/rand_test.go +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ɓngel OrtuƱo. - * See the LICENSE file for more information. - */ - -package util - -import ( - "encoding/hex" - "math/rand" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestRandomBytes(t *testing.T) { - rand.Seed(1234) - r1 := hex.EncodeToString(RandomBytes(16)) - - rand.Seed(3456) - r2 := hex.EncodeToString(RandomBytes(16)) - - require.Equal(t, 32, len(r1)) - require.Equal(t, 32, len(r2)) - require.Equal(t, "c28bed645434c46376369bc5cc400b4c", r1) - require.Equal(t, "067af84b676f17b0dac36bbaa455148a", r2) -} diff --git a/runqueue/mpsc/mpsc.go b/util/runqueue/mpsc/mpsc.go similarity index 95% rename from runqueue/mpsc/mpsc.go rename to util/runqueue/mpsc/mpsc.go index 1bb63a4c3..9cacd14e1 100644 --- a/runqueue/mpsc/mpsc.go +++ b/util/runqueue/mpsc/mpsc.go @@ -17,12 +17,14 @@ type node struct { val interface{} } +// Queue represents a lock-free MPSC queue. type Queue struct { head *node tail *node stub node } +// New returns an empty Queue. func New() *Queue { q := &Queue{} q.head = &q.stub diff --git a/runqueue/mpsc/mpsc_test.go b/util/runqueue/mpsc/mpsc_test.go similarity index 100% rename from runqueue/mpsc/mpsc_test.go rename to util/runqueue/mpsc/mpsc_test.go diff --git a/runqueue/runqueue.go b/util/runqueue/runqueue.go similarity index 67% rename from runqueue/runqueue.go rename to util/runqueue/runqueue.go index 0a4c4ae8b..b2b3914c6 100644 --- a/runqueue/runqueue.go +++ b/util/runqueue/runqueue.go @@ -1,10 +1,16 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + package runqueue import ( + "runtime" "sync/atomic" "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/runqueue/mpsc" + "github.com/ortuman/jackal/util/runqueue/mpsc" ) const ( @@ -12,6 +18,7 @@ const ( running ) +// RunQueue represents a lock-free operation queue. type RunQueue struct { name string queue *mpsc.Queue @@ -23,6 +30,7 @@ type RunQueue struct { type funcMessage struct{ fn func() } type stopMessage struct{ stopCb func() } +// New returns an initialized lock-free operation queue. func New(name string) *RunQueue { return &RunQueue{ name: name, @@ -30,6 +38,7 @@ func New(name string) *RunQueue { } } +// Run pushes a new operation function into the queue. func (m *RunQueue) Run(fn func()) { if atomic.LoadInt32(&m.stopped) == 1 { return @@ -39,6 +48,10 @@ func (m *RunQueue) Run(fn func()) { m.schedule() } +// Stop signals the queue to stop running. +// +// Callback function represented by 'stopCb' its guaranteed to be immediately executed only if no job has been +// previously scheduled. func (m *RunQueue) Stop(stopCb func()) { if atomic.CompareAndSwapInt32(&m.stopped, 0, 1) { if atomic.LoadInt32(&m.messageCount) > 0 { @@ -75,10 +88,9 @@ process: } func (m *RunQueue) run() { - defer func() { if err := recover(); err != nil { - log.Debugf("run queue %s panicked with error: %v", m.name, err) + m.logStackTrace(err) } }() @@ -97,3 +109,10 @@ func (m *RunQueue) run() { } } } + +func (m *RunQueue) logStackTrace(err interface{}) { + stackSlice := make([]byte, 4096) + s := runtime.Stack(stackSlice, false) + + log.Errorf("runqueue '%s' panicked with error: %v\n%s", m.name, err, stackSlice[0:s]) +} diff --git a/runqueue/runqueue_test.go b/util/runqueue/runqueue_test.go similarity index 87% rename from runqueue/runqueue_test.go rename to util/runqueue/runqueue_test.go index fd75b6823..e586bafcd 100644 --- a/runqueue/runqueue_test.go +++ b/util/runqueue/runqueue_test.go @@ -1,3 +1,8 @@ +/* + * Copyright (c) 2019 Miguel Ɓngel OrtuƱo. + * See the LICENSE file for more information. + */ + package runqueue import ( diff --git a/util/string.go b/util/string/string.go similarity index 95% rename from util/string.go rename to util/string/string.go index 3da8c23ca..7f70a8d4c 100644 --- a/util/string.go +++ b/util/string/string.go @@ -3,7 +3,7 @@ * See the LICENSE file for more information. */ -package util +package utilstring // SplitKeyAndValue splits a string between 'key' and 'value' sub elements. func SplitKeyAndValue(str string, sep byte) (key string, value string) { diff --git a/util/string_test.go b/util/string/string_test.go similarity index 95% rename from util/string_test.go rename to util/string/string_test.go index f95b37e93..cc310ffeb 100644 --- a/util/string_test.go +++ b/util/string/string_test.go @@ -3,7 +3,7 @@ * See the LICENSE file for more information. */ -package util +package utilstring import ( "testing" diff --git a/util/tls.go b/util/tls/tls.go similarity index 90% rename from util/tls.go rename to util/tls/tls.go index 32d3a77ba..53492cc41 100644 --- a/util/tls.go +++ b/util/tls/tls.go @@ -3,7 +3,7 @@ * See the LICENSE file for more information. */ -package util +package utiltls import ( "crypto/rand" @@ -93,17 +93,20 @@ func generateSelfSignedCertificate(keyFile, certFile, domain string) error { if err != nil { return err } - defer certOut.Close() - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + defer func() { _ = certOut.Close() }() + + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + return err + } // encode private key keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } - defer keyOut.Close() - pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - return nil + defer func() { _ = keyOut.Close() }() + + return pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) } func selfSignedCertificateExists() bool { diff --git a/util/tls_test.go b/util/tls/tls_test.go similarity index 84% rename from util/tls_test.go rename to util/tls/tls_test.go index 742a71756..aa87123a7 100644 --- a/util/tls_test.go +++ b/util/tls/tls_test.go @@ -3,7 +3,7 @@ * See the LICENSE file for more information. */ -package util +package utiltls import ( "crypto/tls" @@ -15,7 +15,7 @@ import ( func TestLoadCertificate(t *testing.T) { t.Run("Valid", func(t *testing.T) { - tlsCfg, err := LoadCertificate("../testdata/cert/test.server.key", "../testdata/cert/test.server.crt", "localhost") + tlsCfg, err := LoadCertificate("../../testdata/cert/test.server.key", "../../testdata/cert/test.server.crt", "localhost") require.Nil(t, err) require.NotNil(t, tlsCfg) }) diff --git a/version/version.go b/version/version.go index 67d8e66b9..7b3810008 100644 --- a/version/version.go +++ b/version/version.go @@ -10,7 +10,7 @@ import ( ) // ApplicationVersion represents application version. -var ApplicationVersion = NewVersion(0, 5, 1) +var ApplicationVersion = NewVersion(0, 10, 1) // SemanticVersion represents version information with Semantic Versioning specifications. type SemanticVersion struct { diff --git a/xmpp/element.go b/xmpp/element.go index 66a72c57f..bc7bd6fef 100644 --- a/xmpp/element.go +++ b/xmpp/element.go @@ -155,50 +155,81 @@ func (e *Element) String() string { buf := bufPool.Get() defer bufPool.Put(buf) - e.ToXML(buf, true) + _ = e.ToXML(buf, true) return buf.String() } // ToXML serializes element to a raw XML representation. // includeClosing determines if closing tag should be attached. -func (e *Element) ToXML(w io.Writer, includeClosing bool) { - io.WriteString(w, "<") - io.WriteString(w, e.name) +func (e *Element) ToXML(w io.Writer, includeClosing bool) error { + if _, err := io.WriteString(w, "<"); err != nil { + return err + } + if _, err := io.WriteString(w, e.name); err != nil { + return err + } // serialize attributes for _, attr := range e.attrs { if len(attr.Value) == 0 { continue } - io.WriteString(w, " ") - io.WriteString(w, attr.Label) - io.WriteString(w, `="`) - io.WriteString(w, attr.Value) - io.WriteString(w, `"`) + if _, err := io.WriteString(w, ` `); err != nil { + return err + } + if _, err := io.WriteString(w, attr.Label); err != nil { + return err + } + if _, err := io.WriteString(w, `="`); err != nil { + return err + } + if _, err := io.WriteString(w, attr.Value); err != nil { + return err + } + if _, err := io.WriteString(w, `"`); err != nil { + return err + } } + // serialize elements if e.elements.Count() > 0 || len(e.text) > 0 { - io.WriteString(w, ">") - + if _, err := io.WriteString(w, ">"); err != nil { + return err + } if len(e.text) > 0 { - escapeText(w, []byte(e.text), false) + if err := escapeText(w, []byte(e.text), false); err != nil { + return err + } } for _, elem := range e.elements { - elem.ToXML(w, true) + if err := elem.ToXML(w, true); err != nil { + return err + } } if includeClosing { - io.WriteString(w, "") + if _, err := io.WriteString(w, ""); err != nil { + return err + } } } else { if includeClosing { - io.WriteString(w, "/>") + if _, err := io.WriteString(w, "/>"); err != nil { + return err + } } else { - io.WriteString(w, ">") + if _, err := io.WriteString(w, ">"); err != nil { + return err + } } } + return nil } // FromBytes deserializes an element node from it's gob binary representation. diff --git a/xmpp/element_set.go b/xmpp/element_set.go index 71390a1d3..babf24fe1 100644 --- a/xmpp/element_set.go +++ b/xmpp/element_set.go @@ -125,15 +125,17 @@ func (es *elementSet) FromBytes(buf *bytes.Buffer) error { if err := dec.Decode(&c); err != nil { return err } - set := make([]XElement, c) - for i := 0; i < c; i++ { - el, err := NewElementFromBytes(buf) - if err != nil { - return err + if c > 0 { + set := make([]XElement, c) + for i := 0; i < c; i++ { + el, err := NewElementFromBytes(buf) + if err != nil { + return err + } + set[i] = el } - set[i] = el + *es = set } - *es = set return nil } diff --git a/xmpp/element_test.go b/xmpp/element_test.go index 3c07467b1..d2d7a18ff 100644 --- a/xmpp/element_test.go +++ b/xmpp/element_test.go @@ -9,10 +9,8 @@ import ( "bytes" "testing" + "github.com/google/uuid" "github.com/ortuman/jackal/xmpp/jid" - - "github.com/pborman/uuid" - "github.com/stretchr/testify/require" ) @@ -28,7 +26,7 @@ func TestElement_NewElement(t *testing.T) { func TestElement_NewError(t *testing.T) { j, _ := jid.New("", "jackal.im", "", true) - e1 := NewIQType(uuid.New(), GetType) + e1 := NewIQType(uuid.New().String(), GetType) e1.SetFromJID(j) e1.SetToJID(j) e2 := NewErrorStanzaFromStanza(e1, ErrNotAuthorized, nil) @@ -69,15 +67,15 @@ func TestElement_ToXML(t *testing.T) { e1.AppendElement(NewElementName("a")) e1.AppendElement(NewElementName("b")) buf := new(bytes.Buffer) - e1.ToXML(buf, true) + _ = e1.ToXML(buf, true) require.Equal(t, `Hi!`, buf.String()) buf.Reset() e1.ClearElements() e1.SetText("") - e1.ToXML(buf, true) + _ = e1.ToXML(buf, true) require.Equal(t, ``, buf.String()) buf.Reset() - e1.ToXML(buf, false) + _ = e1.ToXML(buf, false) require.Equal(t, ``, buf.String()) } diff --git a/xmpp/error.go b/xmpp/error.go index 36940169b..ed7cea88d 100644 --- a/xmpp/error.go +++ b/xmpp/error.go @@ -69,6 +69,7 @@ const ( subscriptionRequiredErrorReason = "subscription-required" undefinedConditionErrorReason = "undefined-condition" unexpectedConditionErrorReason = "unexpected-condition" + unexpectedRequestErrorReason = "unexpected-request" ) var ( @@ -165,6 +166,10 @@ var ( // ErrUnexpectedCondition is returned by the stream when the recipient or server // understood the request but was not expecting it at this time. ErrUnexpectedCondition = newStanzaError(400, waitErrorType, unexpectedConditionErrorReason) + + // ErrUnexpectedRequest is returned by the stream when the recipient or server + // understood the request but was not expecting it at this time. + ErrUnexpectedRequest = newStanzaError(400, cancelErrorType, unexpectedRequestErrorReason) ) // BadRequestError returns an error copy of the element @@ -298,3 +303,9 @@ func (s *stanzaElement) UndefinedConditionError() Stanza { func (s *stanzaElement) UnexpectedConditionError() Stanza { return NewErrorStanzaFromStanza(s, ErrUnexpectedCondition, nil) } + +// UnexpectedRequestError returns an error copy of the element +// attaching 'unexpected-request' error sub element. +func (s *stanzaElement) UnexpectedRequestError() Stanza { + return NewErrorStanzaFromStanza(s, ErrUnexpectedRequest, nil) +} diff --git a/xmpp/error_test.go b/xmpp/error_test.go index b07cfdef7..452a8075c 100644 --- a/xmpp/error_test.go +++ b/xmpp/error_test.go @@ -34,6 +34,7 @@ func TestError(t *testing.T) { require.Equal(t, subscriptionRequiredErrorReason, ErrSubscriptionRequired.Error()) require.Equal(t, undefinedConditionErrorReason, ErrUndefinedCondition.Error()) require.Equal(t, unexpectedConditionErrorReason, ErrUnexpectedCondition.Error()) + require.Equal(t, unexpectedRequestErrorReason, ErrUnexpectedRequest.Error()) j, _ := jid.New("", "jackal.im", "", true) e := NewIQType(uuid.New(), GetType) @@ -62,4 +63,5 @@ func TestError(t *testing.T) { require.NotNil(t, e.SubscriptionRequiredError().Error().Elements().Child(subscriptionRequiredErrorReason)) require.NotNil(t, e.UndefinedConditionError().Error().Elements().Child(undefinedConditionErrorReason)) require.NotNil(t, e.UnexpectedConditionError().Error().Elements().Child(unexpectedConditionErrorReason)) + require.NotNil(t, e.UnexpectedRequestError().Error().Elements().Child(unexpectedRequestErrorReason)) } diff --git a/xmpp/jid/jid.go b/xmpp/jid/jid.go index ebca00516..6734cd617 100644 --- a/xmpp/jid/jid.go +++ b/xmpp/jid/jid.go @@ -13,7 +13,7 @@ import ( "strings" "unicode/utf8" - "github.com/ortuman/jackal/pool" + "github.com/ortuman/jackal/util/pool" "golang.org/x/net/idna" "golang.org/x/text/secure/precis" ) @@ -168,8 +168,20 @@ func (j *JID) IsFullWithUser() bool { return len(j.node) > 0 && len(j.resource) > 0 } -// Matches returns true if two JID's are equivalent. -func (j *JID) Matches(j2 *JID, options MatchingOptions) bool { +// Matches tells whether or not j2 matches j. +func (j *JID) Matches(j2 *JID) bool { + if j.IsFullWithUser() { + return j.MatchesWithOptions(j2, MatchesNode|MatchesDomain|MatchesResource) + } else if j.IsFullWithServer() { + return j.MatchesWithOptions(j2, MatchesDomain|MatchesResource) + } else if j.IsBare() { + return j.MatchesWithOptions(j2, MatchesNode|MatchesDomain) + } + return j.MatchesWithOptions(j2, MatchesDomain) +} + +// MatchesWithOptions tells whether two jids are equivalent based on matching options. +func (j *JID) MatchesWithOptions(j2 *JID, options MatchingOptions) bool { if (options&MatchesNode) > 0 && j.node != j2.node { return false } diff --git a/xmpp/jid/jid_test.go b/xmpp/jid/jid_test.go index 7f7e0694e..97a853932 100644 --- a/xmpp/jid/jid_test.go +++ b/xmpp/jid/jid_test.go @@ -87,15 +87,15 @@ func TestMatchesJID(t *testing.T) { j3, _ := jid.NewWithString("example.org", false) j4, _ := jid.NewWithString("example.org/res1", false) j6, _ := jid.NewWithString("ortuman@example2.org/res2", false) - require.True(t, j1.Matches(j1, jid.MatchesNode|jid.MatchesDomain|jid.MatchesResource)) - require.True(t, j1.Matches(j2, jid.MatchesNode|jid.MatchesDomain)) - require.True(t, j1.Matches(j3, jid.MatchesDomain)) - require.True(t, j1.Matches(j4, jid.MatchesDomain|jid.MatchesResource)) + require.True(t, j1.MatchesWithOptions(j1, jid.MatchesNode|jid.MatchesDomain|jid.MatchesResource)) + require.True(t, j1.MatchesWithOptions(j2, jid.MatchesNode|jid.MatchesDomain)) + require.True(t, j1.MatchesWithOptions(j3, jid.MatchesDomain)) + require.True(t, j1.MatchesWithOptions(j4, jid.MatchesDomain|jid.MatchesResource)) - require.False(t, j1.Matches(j2, jid.MatchesNode|jid.MatchesDomain|jid.MatchesResource)) - require.False(t, j6.Matches(j2, jid.MatchesNode|jid.MatchesDomain)) - require.False(t, j6.Matches(j3, jid.MatchesDomain)) - require.False(t, j6.Matches(j4, jid.MatchesDomain|jid.MatchesResource)) + require.False(t, j1.MatchesWithOptions(j2, jid.MatchesNode|jid.MatchesDomain|jid.MatchesResource)) + require.False(t, j6.MatchesWithOptions(j2, jid.MatchesNode|jid.MatchesDomain)) + require.False(t, j6.MatchesWithOptions(j3, jid.MatchesDomain)) + require.False(t, j6.MatchesWithOptions(j4, jid.MatchesDomain|jid.MatchesResource)) } func TestBadPrep(t *testing.T) { diff --git a/xmpp/parser.go b/xmpp/parser.go index 9b112df56..0b0309de1 100644 --- a/xmpp/parser.go +++ b/xmpp/parser.go @@ -15,8 +15,7 @@ import ( const rootElementIndex = -1 const ( - streamName = "stream" - framedStreamNamespace = "urn:ietf:params:xml:ns:xmpp-framing" + streamName = "stream" ) // ParsingMode defines the way in which special parsed element @@ -29,9 +28,6 @@ const ( // SocketStream treats incoming elements as provided from a socket transport. SocketStream - - // WebSocketStream treats incoming elements as provided from a websocket transport. - WebSocketStream ) // ErrTooLargeStanza is returned by ReadElement when the size of @@ -111,9 +107,7 @@ func (p *Parser) ParseElement() (XElement, error) { done: p.lastOffset = p.dec.InputOffset() ret := p.nextElement - if p.mode == WebSocketStream && ret.Name() == "close" && ret.Namespace() == framedStreamNamespace { - return nil, ErrStreamClosedByPeer - } + p.nextElement = nil return ret, nil } diff --git a/xmpp/parser_test.go b/xmpp/parser_test.go index feedf1c72..019f4d0bf 100644 --- a/xmpp/parser_test.go +++ b/xmpp/parser_test.go @@ -56,11 +56,6 @@ func TestParser_Close(t *testing.T) { p := xmpp.NewParser(strings.NewReader(src), xmpp.SocketStream, 0) _, err := p.ParseElement() require.Equal(t, xmpp.ErrStreamClosedByPeer, err) - - src = `\n` - p = xmpp.NewParser(strings.NewReader(src), xmpp.WebSocketStream, 0) - _, err = p.ParseElement() - require.Equal(t, xmpp.ErrStreamClosedByPeer, err) } func TestParser_ParseSeveralElements(t *testing.T) { diff --git a/xmpp/presence.go b/xmpp/presence.go index 160a1f2d6..b446d130d 100644 --- a/xmpp/presence.go +++ b/xmpp/presence.go @@ -14,6 +14,8 @@ import ( "github.com/ortuman/jackal/xmpp/jid" ) +const capabilitiesNamespace = "http://jabber.org/protocol/caps" + const ( // AvailableType represents an 'available' Presence type. AvailableType = "" @@ -37,6 +39,13 @@ const ( ProbeType = "probe" ) +// Capabilities represents presence entity capabilities +type Capabilities struct { + Node string + Hash string + Ver string +} + // ShowState represents Presence show state. type ShowState int @@ -168,6 +177,20 @@ func (p *Presence) Priority() int8 { return p.priority } +// Capabilities returns presence stanza capabilities element +func (p *Presence) Capabilities() *Capabilities { + c := p.Elements().ChildNamespace("c", capabilitiesNamespace) + if c == nil { + return nil + } + attribs := c.Attributes() + return &Capabilities{ + Node: attribs.Get("node"), + Hash: attribs.Get("hash"), + Ver: attribs.Get("ver"), + } +} + func isPresenceType(presenceType string) bool { switch presenceType { case ErrorType, AvailableType, UnavailableType, SubscribeType, diff --git a/xmpp/presence_test.go b/xmpp/presence_test.go index b2dc921b3..0f39a537a 100644 --- a/xmpp/presence_test.go +++ b/xmpp/presence_test.go @@ -117,6 +117,21 @@ func TestPresenceBuild(t *testing.T) { presence, err = xmpp.NewPresenceFromElement(elem, j, j) require.Nil(t, err) require.Equal(t, "Readable text", presence.Status()) + + elem.ClearElements() + c := xmpp.NewElementNamespace("c", "http://jabber.org/protocol/caps") + c.SetAttribute("hash", "sha-1") + c.SetAttribute("node", "http://code.google.com/p/exodus") + c.SetAttribute("ver", "QgayPKawpkPSDYmwT/WM94uAlu0=") + elem.AppendElement(c) + presence, err = xmpp.NewPresenceFromElement(elem, j, j) + require.Nil(t, err) + + caps := presence.Capabilities() + require.NotNil(t, caps) + require.Equal(t, "sha-1", caps.Hash) + require.Equal(t, "http://code.google.com/p/exodus", caps.Node) + require.Equal(t, "QgayPKawpkPSDYmwT/WM94uAlu0=", caps.Ver) } func TestPresenceType(t *testing.T) { diff --git a/xmpp/xmpp.go b/xmpp/xmpp.go index dac74f480..82113cc6d 100644 --- a/xmpp/xmpp.go +++ b/xmpp/xmpp.go @@ -10,7 +10,7 @@ import ( "fmt" "io" - "github.com/ortuman/jackal/pool" + "github.com/ortuman/jackal/util/pool" "github.com/ortuman/jackal/xmpp/jid" ) @@ -42,7 +42,7 @@ type XElement interface { IsError() bool Error() XElement - ToXML(w io.Writer, includeClosing bool) + ToXML(w io.Writer, includeClosing bool) error ToBytes(buf *bytes.Buffer) error }