Java FFI: make `MModule::java_file` public
[nit.git] / src / ffi / java.nit
index dd32712..0dd65be 100644 (file)
@@ -193,8 +193,8 @@ end
 redef class MModule
        private var callbacks_used_from_java = new ForeignCallbackSet
 
-       # Pure java class source file
-       private var java_file: nullable JavaClassTemplate = null
+       # Java source file extracted from user FFI code with generated structure
+       var java_file: nullable JavaClassTemplate = null
 
        # Set up the templates of the Java implementation class
        private fun ensure_java_files
@@ -220,6 +220,9 @@ redef class MModule
                for cb in callbacks do
                        jni_methods.add_all(cb.jni_methods_declaration(self))
                end
+               for cb in callbacks_used_from_java.types do
+                       jni_methods.add_all(cb.jni_methods_declaration(self))
+               end
 
                var cf = new CFunction("void nit_ffi_with_java_register_natives(JNIEnv* env, jclass jclazz)")
                cf.exprs.add """
@@ -470,6 +473,38 @@ redef class MType
        # Used by `JavaLanguage::compile_extern_method` when calling JNI's `CallStatic*Method`.
        # This strategy is used by JNI to type the return of callbacks to Java.
        private fun jni_signature_alt: String do return "Int"
+
+       redef fun compile_callback_to_java(mmodule, mainmodule, ccu)
+       do
+               var java_file = mmodule.java_file
+               if java_file == null then return
+
+               for variation in ["incr", "decr"] do
+                       var friendly_name = "{mangled_cname}_{variation}_ref"
+
+                       # C
+                       var csignature = "void {mmodule.impl_java_class_name}_{friendly_name}(JNIEnv *env, jclass clazz, jint object)"
+                       var cf = new CFunction("JNIEXPORT {csignature}")
+                       cf.exprs.add "\tnitni_global_ref_{variation}((void*)(long)object);"
+                       ccu.add_non_static_local_function cf
+
+                       # Java
+                       java_file.class_content.add "private native static void {friendly_name}(int object);\n"
+               end
+       end
+
+       redef fun jni_methods_declaration(from_mmodule)
+       do
+               var arr = new Array[String]
+               for variation in ["incr", "decr"] do
+                       var friendly_name = "{mangled_cname}_{variation}_ref"
+                       var jni_format = "(I)V"
+                       var cname = "{from_mmodule.impl_java_class_name}_{friendly_name}"
+                       arr.add """{"{{{friendly_name}}}", "{{{jni_format}}}", {{{cname}}}}"""
+               end
+
+               return arr
+       end
 end
 
 redef class MClassType
@@ -483,6 +518,11 @@ redef class MClassType
                if mclass.name == "Int" then return "long"
                if mclass.name == "Float" then return "double"
                if mclass.name == "Byte" then return "byte"
+               if mclass.name == "Int8" then return "byte"
+               if mclass.name == "Int16" then return "short"
+               if mclass.name == "UInt16" then return "short"
+               if mclass.name == "Int32" then return "int"
+               if mclass.name == "UInt32" then return "int"
                return super
        end
 
@@ -495,6 +535,11 @@ redef class MClassType
                if mclass.name == "Int" then return "jlong"
                if mclass.name == "Float" then return "jdouble"
                if mclass.name == "Byte" then return "jbyte"
+               if mclass.name == "Int8" then return "jbyte"
+               if mclass.name == "Int16" then return "jshort"
+               if mclass.name == "UInt16" then return "jshort"
+               if mclass.name == "Int32" then return "jint"
+               if mclass.name == "UInt32" then return "jint"
                return super
        end
 
@@ -555,6 +600,11 @@ redef class MClassType
                if mclass.name == "Int" then return "J"
                if mclass.name == "Float" then return "D"
                if mclass.name == "Byte" then return "B"
+               if mclass.name == "Int8" then return "B"
+               if mclass.name == "Int16" then return "S"
+               if mclass.name == "UInt16" then return "S"
+               if mclass.name == "Int32" then return "I"
+               if mclass.name == "UInt32" then return "I"
                return super
        end
 
@@ -568,6 +618,11 @@ redef class MClassType
                if mclass.name == "Int" then return "Long"
                if mclass.name == "Float" then return "Double"
                if mclass.name == "Byte" then return "Byte"
+               if mclass.name == "Int8" then return "Byte"
+               if mclass.name == "Int16" then return "Short"
+               if mclass.name == "UInt16" then return "Short"
+               if mclass.name == "Int32" then return "Int"
+               if mclass.name == "UInt32" then return "Int"
                return super
        end
 end