メモ:#[extendr] の実装を追ってみる

こういう関数があるとする。

#[extendr(use_try_from = true)]
fn string(input: &str) -> String {
    input.to_string()
}

cargo expand すると以下が出てくる。

fn string(input: &str) -> String {
    input.to_string()
}

#[no_mangle]
#[allow(non_snake_case)]
pub extern "C" fn wrap__string(input: extendr_api::SEXP) -> extendr_api::SEXP {
    unsafe {
        use extendr_api::robj::*;
        let _input_robj = extendr_api::new_owned(input);
        extendr_api::handle_panic("string panicked.\0", || {
            extendr_api::Robj::from(string(extendr_api::unwrap_or_throw_error(
                _input_robj
                    .try_into()
                    .map_err(|e| extendr_api::Error::from(e)),
            )))
            .get()
        })
    }
}

#[allow(non_snake_case)]
fn meta__string(metadata: &mut Vec<extendr_api::metadata::Func>) {
    let mut args = <[_]>::into_vec(
        #[rustc_box]
        ::alloc::boxed::Box::new([extendr_api::metadata::Arg {
            name: "input",
            arg_type: "str",
            default: None,
        }]),
    );
    metadata.push(extendr_api::metadata::Func {
        doc: " Return a dynamic string.\n\n @export",
        rust_name: "string",
        r_name: "string",
        mod_name: "string",
        args: args,
        return_type: "String",
        func_ptr: wrap__string as *const u8,
        hidden: false,
    })
}

meta__string()はnative routine registration用なのでいったん置いておいて、 wrap__string() を見る。 これは、実際の実装であるstring()を名前通りラップしていて、具体的には、

  • 入力のSEXPstring()が受け取る型に変換
  • 出力をSEXPに変換
  • panic!()でunwindが起こるとまずいので、起こらないようにする。なぜまずいかというと、FFI越しのunwindはundefined behaviorなので(参考:FFI - The Rustonomicon

ということをやっている。

入力のSEXPstring()が受け取る型に変換

new_owned()

まずはこの行。inputextendr_api::SEXPという型だが、これはlibR-sysの定義そのまま。

        let _input_robj = extendr_api::new_owned(input);

new_owned()の実装はこのへん

与えられたSEXPを protect して、それをRobjにして返す、ということをしている。 ややこしいが、この「protect」は、単にRf_protect()を実行するということではなく、extendr側で管理するオブジェクトのリストにその対象を追加する、ということになるらしい。

#[doc(hidden)]
pub unsafe fn new_owned(sexp: SEXP) -> Robj {
    single_threaded(|| {
        ownership::protect(sexp);
        Robj { inner: sexp }
    })
}

single_threaded()の実装はこのへん。 あんまりちゃんと理解できてないけど、同時に呼び出されている場合には順番を待つ、ということらしい。 ちなみに、cpp11の場合もdouble linked listを使って似たようなことを実装しているらしい

use std::sync::atomic::{AtomicU32, Ordering};

static OWNER_THREAD: AtomicU32 = AtomicU32::new(0);
static NEXT_THREAD_ID: AtomicU32 = AtomicU32::new(1);

thread_local! {
    static THREAD_ID: u32 = NEXT_THREAD_ID.fetch_add(1, Ordering::SeqCst);
}

pub fn this_thread_id() -> u32 {
    THREAD_ID.with(|&v| v)
}

pub fn single_threaded<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let id = this_thread_id();
    let old_id = OWNER_THREAD.load(Ordering::Acquire);

    if old_id != id {
        // wait for OWNER_THREAD to become 0 and put us as the owner.
        while OWNER_THREAD
            .compare_exchange(0, id, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
    }

    let res = f();

    if old_id != id {
        // release the lock and signal waiting threads.
        OWNER_THREAD.store(0, Ordering::Release);
    }

    res
}

ownership::protect()の実装はこのへん。 実際の処理はOwnership::protect()にある。

lazy_static! {
    static ref OWNERSHIP: Mutex<Ownership> = Mutex::new(Ownership::new());
}

pub(crate) unsafe fn protect(sexp: SEXP) {
    let mut own = OWNERSHIP.lock().expect("protect failed");
    own.protect(sexp);
}

OwnershipはRのreference countを表すもので、定義はこのへんHashMapを使ってpreserveするオブジェクトを管理している。

struct Object {
    refcount: usize,
    index: usize,
}

// A reference counted object with an index in the preservation vector.
struct Ownership {
    // A growable vector containing all owned objects.
    preservation: usize,

    // An incrementing count of objects through the vector.
    cur_index: usize,

    // The size of the vector.
    max_index: usize,

    // A hash map from SEXP address to object.
    objects: HashMap<usize, Object>,
}

で、Ownership::protect()このへん。 ここはちょっと長いので少しずつ見ていく。 まず、Rf_protect()を呼んでR側に勝手に消されないようにする。

    unsafe fn protect(&mut self, sexp: SEXP) {
        Rf_protect(sexp);

もし管理するオブジェクト数が限界に達していたら GC を行う。

        if self.cur_index == self.max_index {
            self.garbage_collect();
        }

対象のオブジェクトがすでにprotect対象になっているかどうかを調べ、あれば(Entry::Occupied)そのrefcountをインクリメントして、なければ(Entry::Vacant)追加する。

        let sexp_usize = sexp as usize;
        let Ownership {
            ref mut preservation,
            ref mut cur_index,
            ref mut max_index,
            ref mut objects,
        } = *self;

        let mut entry = objects.entry(sexp_usize);
        let preservation_sexp = *preservation as SEXP;
        match entry {
            Entry::Occupied(ref mut occupied) => {
                if occupied.get().refcount == 0 {
                    // Address re-used - re-set the sexp.
                    SET_VECTOR_ELT(preservation_sexp, occupied.get().index as R_xlen_t, sexp);
                }
                occupied.get_mut().refcount += 1;
            }
            Entry::Vacant(vacant) => {
                let index = *cur_index;
                SET_VECTOR_ELT(preservation_sexp, index as R_xlen_t, sexp);
                *cur_index += 1;
                assert!(index != *max_index);
                let refcount = 1;
                vacant.insert(Object { refcount, index });
            }
        }

最後にRf_unprotect()を呼んで終わり。これでR側からかってに消されてしまわないの...?と思って悩んだけど、 SET_VECTOR_ELT()でリストに突っ込んでいるので、そこから参照されている限り消えない、ということだった。

        Rf_unprotect(1);

try_into()

wrap__string()に戻って、次はこのあたり。string()の引数として渡されているうち、unwrap_or_throw_error()は無視して、try_into()を見ていく。

            extendr_api::Robj::from(string(extendr_api::unwrap_or_throw_error(
                _input_robj
                    .try_into()
                    .map_err(|e| extendr_api::Error::from(e)),
            )))

実装はこのあたり

impl TryFrom<&Robj> for &str {
    type Error = Error;

    /// Convert a scalar STRSXP object into a string slice.
    /// NAs are not allowed.
    fn try_from(robj: &Robj) -> Result<Self> {
        if robj.is_na() {
            return Err(Error::MustNotBeNA(robj.clone()));
        }
        match robj.len() {
            0 => Err(Error::ExpectedNonZeroLength(robj.clone())),
            1 => {
                if let Some(s) = robj.as_str() {
                    Ok(s)
                } else {
                    Err(Error::ExpectedString(robj.clone()))
                }
            }
            _ => Err(Error::ExpectedScalar(robj.clone())),
        }
    }
}

as_str()このへん

    pub fn as_str<'a>(&self) -> Option<&'a str> {
        unsafe {
            match self.sexptype() {
                STRSXP => {
                    if self.len() != 1 {
                        None
                    } else {
                        Some(to_str(R_CHAR(STRING_ELT(self.get(), 0)) as *const u8))
                    }
                }
                // CHARSXP => Some(to_str(R_CHAR(self.get()) as *const u8)),
                // SYMSXP => Some(to_str(R_CHAR(PRINTNAME(self.get())) as *const u8)),
                _ => None,
            }
        }
    }
// Internal utf8 to str conversion.
// Lets not worry about non-ascii/unicode strings for now (or ever).
pub(crate) unsafe fn to_str<'a>(ptr: *const u8) -> &'a str {
    let mut len = 0;
    loop {
        if *ptr.offset(len) == 0 {
            break;
        }
        len += 1;
    }
    let slice = std::slice::from_raw_parts(ptr, len as usize);
    std::str::from_utf8_unchecked(slice)
}

出力をSEXPに変換

SEXPからの変換はtry_from()だったけど、SEXPへの変換はfrom()でいける(失敗しない)

macro_rules! impl_str_tvv {
    ($t: ty) => {
        impl ToVectorValue for $t {
            fn sexptype() -> SEXPTYPE {
                STRSXP
            }


            fn to_sexp(&self) -> SEXP
            where
                Self: Sized,
            {
                str_to_character(self.as_ref())
            }
        }

// ...snip...

}


impl_str_tvv! {&str}
impl_str_tvv! {String}

実際の変換はここ。たぶんRf_mkCharLen()じゃなくてRf_mkCharLenCE()を使うべき...(どうせR 4.1以前はサポートしないのでRf_mkCharLen()でもUTF-8になる、ということでだいたいは問題ないけど)。

pub(crate) fn str_to_character(s: &str) -> SEXP {
    unsafe {
        if s.is_na() {
            R_NaString
        } else {
            single_threaded(|| Rf_mkCharLen(s.as_ptr() as *const raw::c_char, s.len() as i32))
        }
    }
}

Rf_mkCharLen()は文字列をコピーするので、Rust側の文字列がこのあと消えようが特に問題はない(逆に言うと、残っていようが1回はコピーされるということ)。

https://github.com/wch/r-source/blob/c747e3e4a78322a42c5783dd340f84f4963e380e/src/main/envir.c#L4258

unwindが起こらないようにする

handle_panic()このへん

pub fn handle_panic<F, R>(err_str: &str, f: F) -> R
where
    F: FnOnce() -> R,
    F: std::panic::UnwindSafe,
{
    match std::panic::catch_unwind(f) {
        Ok(res) => res,
        Err(_) => {
            unsafe {
                libR_sys::Rf_error(err_str.as_ptr() as *const std::os::raw::c_char);
            }
            unreachable!("handle_panic unreachable")
        }
    }
}

catch_unwind()はこれ

https://doc.rust-lang.org/std/panic/fn.catch_unwind.html