diff --git a/README.md b/README.md index b6fddb6..5c41e6e 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,17 @@ Connecting with a specific Hive version (0.12) and using the `:http` transport: connection.fetch('SHOW TABLES') end -We have not tested the SASL connection, as we don't run SASL; pull requests and testing are welcomed. +Connecting with SASL and Kerberos v5: + + RBHive.tcli_connect('hive.hadoop.forward.co.uk', 10_000, { + :transport => :sasl, + :sasl_params => { + :mechanism => 'GSSAPI', + :remote_host => 'example.com', + :remote_principal => 'hive/example.com@EXAMPLE.COM' + ) do |connection| + connection.fetch("show tables") + end #### Hiveserver2 protocol versions diff --git a/lib/rbhive/t_c_l_i_connection.rb b/lib/rbhive/t_c_l_i_connection.rb index e1839e9..da6bb31 100644 --- a/lib/rbhive/t_c_l_i_connection.rb +++ b/lib/rbhive/t_c_l_i_connection.rb @@ -375,7 +375,9 @@ def method_missing(meth, *args) private def prepare_open_session(client_protocol) - req = ::Hive2::Thrift::TOpenSessionReq.new( @options[:sasl_params].nil? ? [] : @options[:sasl_params] ) + req = ::Hive2::Thrift::TOpenSessionReq.new( @options[:sasl_params].nil? ? [] : { + :username => @options[:sasl_params][:username], + :password => @options[:sasl_params][:password]}) req.client_protocol = client_protocol req end diff --git a/lib/thrift/sasl_client_transport.rb b/lib/thrift/sasl_client_transport.rb index 7361a3a..d58ddcb 100644 --- a/lib/thrift/sasl_client_transport.rb +++ b/lib/thrift/sasl_client_transport.rb @@ -1,10 +1,9 @@ -module Thrift - class SaslClientTransport < BufferedTransport - attr_reader :challenge +require 'gssapi' +module Thrift + class SaslClientTransport < FramedTransport STATUS_BYTES = 1 PAYLOAD_LENGTH_BYTES = 4 - AUTH_MECHANISM = 'PLAIN' NEGOTIATION_STATUS = { START: 0x01, OK: 0x02, @@ -15,76 +14,99 @@ class SaslClientTransport < BufferedTransport def initialize(transport, sasl_params={}) super(transport) - @challenge = nil + @sasl_complete = nil @sasl_username = sasl_params.fetch(:username, 'anonymous') @sasl_password = sasl_params.fetch(:password, 'anonymous') - end + @sasl_mechanism = sasl_params.fetch(:mechanism, 'PLAIN') - def read(sz) - len, = @transport.read(PAYLOAD_LENGTH_BYTES).unpack('l>') if @rbuf.nil? - sz = len if len && sz > len - @index += sz - ret = @rbuf.slice(@index - sz, sz) || Bytes.empty_byte_buffer - if ret.length == 0 - @rbuf = @transport.read(len) rescue Bytes.empty_byte_buffer - @index = sz - ret = @rbuf.slice(0, sz) || Bytes.empty_byte_buffer + unless ['PLAIN', 'GSSAPI'].include? @sasl_mechanism + raise "Unknown SASL mechanism: #{@sasl_mechanism}" end - ret - end - def read_byte - reset_buffer! if @index >= @rbuf.size - @index += 1 - Bytes.get_string_byte(@rbuf, @index - 1) + if @sasl_mechanism == 'GSSAPI' + @sasl_remote_principal = sasl_params[:remote_principal] + @sasl_remote_host = sasl_params[:remote_host] + @gsscli = GSSAPI::Simple.new(@sasl_remote_host, @sasl_remote_principal) + end end - def read_into_buffer(buffer, size) - i = 0 - while i < size - reset_buffer! if @index >= @rbuf.size - byte = Bytes.get_string_byte(@rbuf, @index) - Bytes.set_string_byte(buffer, i, byte) - @index += 1 - i += 1 + def open + super + + case @sasl_mechanism + when 'PLAIN' + handshake_plain! + when 'GSSAPI' + handshake_gssapi! end - i end - def write(buf) - initiate_hand_shake if @challenge.nil? - header = [buf.length].pack('l>') - @wbuf << (header + Bytes.force_binary_encoding(buf)) - end + private - protected + def handshake_plain! + token = "[#{@sasl_mechanism}]\u0000#{@sasl_username}\u0000#{@sasl_password}" + write_handshake_message(NEGOTIATION_STATUS[:START], @sasl_mechanism) + write_handshake_message(NEGOTIATION_STATUS[:OK], token) - def initiate_hand_shake - header = [NEGOTIATION_STATUS[:START], AUTH_MECHANISM.length].pack('cl>') - @transport.write header + AUTH_MECHANISM - message = "[#{AUTH_MECHANISM}]\u0000#{@sasl_username}\u0000#{@sasl_password}" - header = [NEGOTIATION_STATUS[:OK], message.length].pack('cl>') - @transport.write header + message - status, len = @transport.read(STATUS_BYTES + PAYLOAD_LENGTH_BYTES).unpack('cl>') + status, msg = read_handshake_message case status - when NEGOTIATION_STATUS[:BAD], NEGOTIATION_STATUS[:ERROR] - raise @transport.to_io.read(len) when NEGOTIATION_STATUS[:COMPLETE] - @challenge = @transport.to_io.read len + @sasl_complete = true when NEGOTIATION_STATUS[:OK] raise "Failed to complete challenge exchange: only NONE supported currently" end end - private + def handshake_gssapi! + token = @gsscli.init_context + write_handshake_message(NEGOTIATION_STATUS[:START], @sasl_mechanism) + write_handshake_message(NEGOTIATION_STATUS[:OK], token) - def reset_buffer! - len, = @transport.read(PAYLOAD_LENGTH_BYTES).unpack('l>') - @rbuf = @transport.read(len) - while @rbuf.size < len - @rbuf << @transport.read(len - @rbuf.size) + status, msg = read_handshake_message + case status + when NEGOTIATION_STATUS[:COMPLETE] + raise "Unexpected COMPLETE from server" + when NEGOTIATION_STATUS[:OK] + unless @gsscli.init_context(msg) + raise "GSSAPI: challenge provided by server could not be verified" + end + + write_handshake_message(NEGOTIATION_STATUS[:OK], "") + + status, msg = read_handshake_message + case status + when NEGOTIATION_STATUS[:COMPLETE] + raise "Unexpected COMPLETE from server" + when NEGOTIATION_STATUS[:OK] + unwrapped = @gsscli.unwrap_message(msg) + rewrapped = @gsscli.wrap_message(unwrapped) + + write_handshake_message(NEGOTIATION_STATUS[:COMPLETE], rewrapped) + + status, msg = read_handshake_message + case status + when NEGOTIATION_STATUS[:COMPLETE] + @sasl_complete = true + when NEGOTIATION_STATUS[:OK] + raise "Failed to complete GSS challenge exchange" + end + end + end + end + + def read_handshake_message + status, len = @transport.read(STATUS_BYTES + PAYLOAD_LENGTH_BYTES).unpack('cl>') + body = @transport.to_io.read(len) + if [NEGOTIATION_STATUS[:BAD], NEGOTIATION_STATUS[:ERROR]].include?(status) + raise "Exception from server: #{body}" end - @index = 0 + + [status, body] + end + + def write_handshake_message(status, message) + header = [status, message.length].pack('cl>') + @transport.write(header + message) end end @@ -93,5 +115,4 @@ def get_transport(transport) return SaslClientTransport.new(transport) end end - end diff --git a/rbhive.gemspec b/rbhive.gemspec index 8ef2cc2..ac1413a 100644 --- a/rbhive.gemspec +++ b/rbhive.gemspec @@ -19,6 +19,7 @@ Gem::Specification.new do |spec| spec.require_paths = ['lib'] spec.add_dependency('thrift', '~> 0.9') + spec.add_dependency('gssapi', '~> 1.2') spec.add_dependency('json') spec.add_development_dependency 'rake'